我是tensorflow的新手,正在研究一个模型,该模型用于给灰度图像上色。在模型的最后部分,论文中提到:
一旦特征融合后,它们将通过一组卷积和上采样层进行处理,后者仅使用最近邻技术对输入进行上采样,使输出宽度和高度增加一倍。
当我在tensorflow中尝试实现时,我使用了tf.image.resize_nearest_neighbor
来进行上采样,但使用后我发现除了第二轮外,成本在所有轮次中都没有变化,而不使用它时,成本会得到优化和变化。
这部分代码
def Model(Input_images): #一些代码直到以下最后一部分 Color_weights = {'W_conv1':tf.Variable(tf.random_normal([3,3,256,128])),'W_conv2':tf.Variable(tf.random_normal([3,3,128,64])), 'W_conv3':tf.Variable(tf.random_normal([3,3,64,64])), 'W_conv4':tf.Variable(tf.random_normal([3,3,64,32])),'W_conv5':tf.Variable(tf.random_normal([3,3,32,2]))} Color_biases = {'b_conv1':tf.Variable(tf.random_normal([128])),'b_conv2':tf.Variable(tf.random_normal([64])),'b_conv3':tf.Variable(tf.random_normal([64])), 'b_conv4':tf.Variable(tf.random_normal([32])),'b_conv5':tf.Variable(tf.random_normal([2]))} Color_layer1 = tf.nn.relu(Conv2d(Fuse, Color_weights['W_conv1'], 1) + Color_biases['b_conv1']) Color_layer1_up = tf.image.resize_nearest_neighbor(Color_layer1,[56,56]) Color_layer2 = tf.nn.relu(Conv2d(Color_layer1_up, Color_weights['W_conv2'], 1) + Color_biases['b_conv2']) Color_layer3 = tf.nn.relu(Conv2d(Color_layer2, Color_weights['W_conv3'], 1) + Color_biases['b_conv3']) Color_layer3_up = tf.image.resize_nearest_neighbor(Color_layer3,[112,112]) Color_layer4 = tf.nn.relu(Conv2d(Color_layer3, Color_weights['W_conv4'], 1) + Color_biases['b_conv4']) return Color_layer4
训练代码
Prediction = Model(Input_images) Colorization_MSE = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(Prediction,tf.Variable(tf.random_normal([2,112,112,32])))) Optmizer = tf.train.AdadeltaOptimizer(learning_rate= 0.05).minimize(Colorization_MSE) sess = tf.InteractiveSession() sess.run(tf.global_variables_initializer()) for epoch in range(EpochsNum): epoch_loss = 0 Batch_indx = 1 for i in range(int(ExamplesNum / Batch_size)):#Over batches print("Batch Num ",i + 1) ReadNextBatch() a, c = sess.run([Optmizer,Colorization_MSE],feed_dict={Input_images:Batch_GreyImages}) epoch_loss += c print("epoch: ",epoch + 1, ",Los: ",epoch_loss)
那么我的逻辑有什么问题,或者如果问题出在tf.image.resize_nearest_neighbor
上,我应该怎么做,或者有什么替代方案?
回答:
好的,我解决了这个问题,我注意到tf.random_normal
是问题所在,当我用tf.truncated_normal
替换它时,效果很好