Keras UpSampling2D 不一致的行为

这是我的模型:

filters = 256kernel_size = 3strides = 1factor = 4  # 上采样的因子inputLayer = Input(shape=(img_height//factor, img_width//factor, img_depth))conv1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(inputLayer)res = Conv2D(filters, kernel_size, strides=strides, padding='same')(conv1)act = ReLU()(res)res = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)res_rec = Add()([conv1, res])for i in range(15):  # 16-1    res1 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)    act = ReLU()(res1)    res2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(act)    res_rec = Add()([res_rec, res2])conv2 = Conv2D(filters, kernel_size, strides=strides, padding='same')(res_rec)a = Add()([conv1, conv2])up = UpSampling2D(size=4)(a)outputLayer = Conv2D(filters=3,                     kernel_size=1,                     strides=1,                     padding='same')(up)model = Model(inputs=inputLayer, outputs=outputLayer)

model.summary()显示如下:

__________________________________________________________________________________________________Layer (type)                    Output Shape         Param #     Connected to                     ==================================================================================================input_1 (InputLayer)            (None, 350, 350, 3)  0                                            __________________________________________________________________________________________________conv2d_1 (Conv2D)               (None, 350, 350, 256 7168        input_1[0][0]                    __________________________________________________________________________________________________conv2d_2 (Conv2D)               (None, 350, 350, 256 590080      conv2d_1[0][0]                   __________________________________________________________________________________________________re_lu_1 (ReLU)                  (None, 350, 350, 256 0           conv2d_2[0][0]                   __________________________________________________________________________________________________conv2d_3 (Conv2D)               (None, 350, 350, 256 590080      re_lu_1[0][0]                    __________________________________________________________________________________________________add_1 (Add)                     (None, 350, 350, 256 0           conv2d_1[0][0]                                                                                    conv2d_3[0][0]                   __________________________________________________________________________________________________conv2d_4 (Conv2D)               (None, 350, 350, 256 590080      add_1[0][0]                      __________________________________________________________________________________________________re_lu_2 (ReLU)                  (None, 350, 350, 256 0           conv2d_4[0][0]                   __________________________________________________________________________________________________conv2d_5 (Conv2D)               (None, 350, 350, 256 590080      re_lu_2[0][0]                    __________________________________________________________________________________________________add_2 (Add)                     (None, 350, 350, 256 0           add_1[0][0]                                                                                       conv2d_5[0][0]                   __________________________________________________________________________________________________conv2d_6 (Conv2D)               (None, 350, 350, 256 590080      add_2[0][0]                      __________________________________________________________________________________________________re_lu_3 (ReLU)                  (None, 350, 350, 256 0           conv2d_6[0][0]                   __________________________________________________________________________________________________conv2d_7 (Conv2D)               (None, 350, 350, 256 590080      re_lu_3[0][0]                    __________________________________________________________________________________________________add_3 (Add)                     (None, 350, 350, 256 0           add_2[0][0]                                                                                       conv2d_7[0][0]                    ...... this goes on for a long time ..... __________________________________________add_15 (Add)                    (None, 350, 350, 256 0           add_14[0][0]                                                                                      conv2d_31[0][0]                  __________________________________________________________________________________________________conv2d_32 (Conv2D)              (None, 350, 350, 256 590080      add_15[0][0]                     __________________________________________________________________________________________________re_lu_16 (ReLU)                 (None, 350, 350, 256 0           conv2d_32[0][0]                  __________________________________________________________________________________________________conv2d_33 (Conv2D)              (None, 350, 350, 256 590080      re_lu_16[0][0]                   __________________________________________________________________________________________________add_16 (Add)                    (None, 350, 350, 256 0           add_15[0][0]                                                                                      conv2d_33[0][0]                  __________________________________________________________________________________________________conv2d_34 (Conv2D)              (None, 350, 350, 256 590080      add_16[0][0]                     __________________________________________________________________________________________________add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                                                                                    conv2d_34[0][0]                  __________________________________________________________________________________________________up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     __________________________________________________________________________________________________conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            ==================================================================================================Total params: 19,480,579Trainable params: 19,480,579Non-trainable params: 0__________________________________________________________________________________________________None

重要部分就在结尾,靠近输出处:

__________________________________________________________________________________________________add_17 (Add)                    (None, 350, 350, 256 0           conv2d_1[0][0]                                                                                    conv2d_34[0][0]                  __________________________________________________________________________________________________up_sampling2d_1 (UpSampling2D)  (None, 1400, 1400, 2 0           add_17[0][0]                     __________________________________________________________________________________________________conv2d_35 (Conv2D)              (None, 1400, 1400, 3 771         up_sampling2d_1[0][0]            ==================================================================================================

现在,看看我运行网络时得到的错误:

Traceback (most recent call last):  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 280, in <module>    setUpImages()  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 96, in setUpImages    setUpData(trainData, testData)  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 135, in setUpData    setUpModel(X_train, Y_train, validateTestData, trainingTestData)  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 176, in setUpModel    train(model, X_train, Y_train, validateTestData, trainingTestData)  File "C:/Users/payne/PycharmProjects/PixelEnhancer/trainTest.py", line 192, in train    batch_size=32)  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 950, in fit    batch_size=batch_size)  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training.py", line 787, in _standardize_user_data    exception_prefix='target')  File "C:\Users\payne\Anaconda3\envs\ml-gpu\lib\site-packages\keras\engine\training_utils.py", line 137, in standardize_input_data    str(data_shape))ValueError: Error when checking target: expected conv2d_35 to have shape (1400, 1400, 1) but got array with shape (1400, 1400, 3)

为什么我的最后一个卷积层期望得到一个形状为(1400, 1400, 1)的张量,但却得到了形状为(1400, 1400, 3)的张量,而摘要中说UpSampling2D应该返回一个形状为(1400, 1400, 2)的张量?

为了澄清一下背景:这应该是一个网络,它接收一个350x350x3的图像,并输出一个1400x1400x3的图像。


回答:

显然,错误信息并不是专门针对conv2d_35实体的,而是与我的损失函数链接的网络的最后一个实体有关的。

由于我选择了sparse_categorical_crossentropy作为损失函数,它期望得到一个单一维度的向量。

将损失函数设置为mean_squared_error解决了这个问题。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注