Tensorflow准确率达到0.99,但预测结果很差

也许我的预测方法有问题?

这是我的项目… 我有一个灰度输入图像,我试图对其进行分割。分割是一个简单的二元分类(可以理解为前景与背景)。所以真实标签(y)是一个由0和1组成的矩阵——因此有两种分类。哦,输入图像是一个正方形,所以我只使用了一个名为n_input的变量

我的准确率基本上收敛到了0.99,但当我进行预测时,我得到的全是零。编辑 –> 在每个输出矩阵中都有一个1,位置相同…

这是我的会话代码(其他一切都正常工作)…

with tf.Session() as sess:    sess.run(init)    summary = tf.train.SummaryWriter('/tmp/logdir/', sess.graph_def)    step = 1    from tensorflow.contrib.learn.python.learn.datasets.scroll import scroll_data    data = scroll_data.read_data('/home/kendall/Desktop/')    # 继续训练直到达到最大迭代次数    flag = 0    # while flag == 0:    while step * batch_size < training_iters:        batch_y, batch_x = data.train.next_batch(batch_size)        # pdb.set_trace()        # batch_x = batch_x.reshape((batch_size, n_input))        batch_x = batch_x.reshape((batch_size, n_input, n_input))        batch_y = batch_y.reshape((batch_size, n_input, n_input))        batch_y = convert_to_2_channel(batch_y, batch_size)        # batch_y = batch_y.reshape((batch_size, n_output, n_classes))        batch_y = batch_y.reshape((batch_size, 200, 200, n_classes))        sess.run(optimizer, feed_dict={x: batch_x, y: batch_y,                                       keep_prob: dropout})        if step % display_step == 0:            flag = 1            # 计算批次损失和准确率            loss, acc = sess.run([cost, accuracy], feed_dict={x: batch_x,                                                              y: batch_y,                                                              keep_prob: 1.})            print "Iter " + str(step*batch_size) + ", Minibatch Loss= " + \                  "{:.6f}".format(loss) + ", Training Accuracy= " + \                  "{:.5f}".format(acc)        step += 1    print "Optimization Finished!"    save_path = "model.ckpt"    saver.save(sess, save_path)    im = Image.open('/home/kendall/Desktop/HA900_frames/frame0635.tif')    batch_x = np.array(im)    pdb.set_trace()    batch_x = batch_x.reshape((1, n_input, n_input))    batch_x = batch_x.astype(float)    # pdb.set_trace()    prediction = sess.run(pred, feed_dict={x: batch_x, keep_prob: 1.})    print prediction    arr1 = np.empty((n_input,n_input))    arr2 = np.empty((n_input,n_input))    for i in xrange(n_input):        for j in xrange(n_input):            for k in xrange(2):                if k == 0:                    arr1[i][j] = prediction[0][i][j][k]                else:                    arr2[i][j] = prediction[0][i][j][k]    # prediction = np.asarray(prediction)    # prediction = np.reshape(prediction, (200,200))    # np.savetxt("prediction.csv", prediction, delimiter=",")    np.savetxt("prediction1.csv", arr1, delimiter=",")    np.savetxt("prediction2.csv", arr2, delimiter=",")

由于有两种分类,最后那部分(带有几个循环的部分)只是为了将预测结果分成两个2×2的矩阵。

我将预测数组保存到了CSV文件中,就像我说的,它们全是零。

我还确认了所有数据都是正确的(维度和值)。

为什么训练会收敛,但预测结果很差呢?

如果你想查看所有代码,这里是…

...

回答:

代码中的错误

你的代码中有多个错误:

警告:此操作期望未缩放的logits,因为它在内部对logits执行softmax以提高效率。不要用softmax的输出调用此操作,因为它会产生错误的结果。

  • 事实上,由于你有两个类,你应该使用带有softmax的损失,使用tf.nn.softmax_cross_entropy_with_logits

  • 当使用tf.argmax(pred, 1)时,你只在轴1上应用argmax,这是输出图像的高度。你应该在最后一个轴(大小为2)上使用tf.argmax(pred, 3)

    • 这可能解释了为什么你得到0.99的准确率
    • 在输出图像上,它将在图像的高度上进行argmax,默认情况下是0(因为每个通道的所有值都相等)

模型错误

最大的缺点是你的模型总体上将非常难优化。

  • 你对40,000个类进行了softmax,这太大了。
  • 你完全没有利用你想要输出图像(预测前景/背景)的事实。
    • 例如,预测2,345与预测2,346和预测2,545高度相关,但你没有考虑到这一点

我建议你先阅读一些关于语义分割的资料:

  • 这篇论文:用于语义分割的全卷积网络
  • 这些幻灯片来自CS231n(斯坦福):特别是关于上采样和反卷积的部分

建议

如果你想使用TensorFlow,你需要从小处着手。首先尝试一个非常简单的网络,可能只有一个隐藏层。

你需要绘制所有张量的形状,以确保它们符合你的预期。例如,如果你绘制了tf.argmax(y, 1),你会发现形状是[batch_size, 200, 2],而不是预期的[batch_size, 200, 200]

TensorBoard是你的朋友,你应该尝试在这里绘制输入图像,以及你的预测结果,看看它们是什么样子。

尝试从小处着手,使用一个非常小的数据集,10张图像,看看你是否能过拟合它并几乎精确地预测响应。


总之,我对所有建议的准确性不确定,但它们值得一试,我希望这能帮助你在成功的道路上前进!

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注