在TensorFlow中使用同一个图表显示训练和验证准确度

我有一个TensorFlow模型,其中一部分用于评估准确度。该accuracy只是TensorFlow图中的另一个节点,它接受logitslabels作为输入。

当我想绘制训练准确度时,这很简单:我有类似这样的代码:

tf.scalar_summary("Training Accuracy", accuracy)tf.scalar_summary("SomethingElse", foo)summary_op = tf.merge_all_summaries()writer = tf.train.SummaryWriter('/me/mydir/', graph=sess.graph)

然后,在我的训练循环中,我有类似这样的代码:

for n in xrange(1000):  ...  summary, ..., ... = sess.run([summary_op, ..., ...], feed_dict)  writer.add_summary(summary, n)  ...

同样,在这个循环中,每隔100次迭代,我想评估验证准确度。我为此准备了一个单独的feed_dict,并且能够在Python中很好地评估验证准确度。

然而,我的困惑在于:我想为验证准确度创建另一个摘要,通过使用accuracy节点来实现。我不太清楚该如何操作。虽然我有accuracy节点,理论上我应该可以重用它,但我不知道具体如何操作,以便将验证准确度作为一个独立的scalar_summary输出…

这可能如何实现呢?


回答:

你可以重用accuracy节点,但你需要使用两个不同的SummaryWriters,一个用于训练运行,另一个用于测试数据。另外,你需要将准确度的标量摘要分配给一个变量。

accuracy_summary = tf.scalar_summary("Training Accuracy", accuracy)tf.scalar_summary("SomethingElse", foo)summary_op = tf.merge_all_summaries()summaries_dir = '/me/mydir/'train_writer = tf.train.SummaryWriter(summaries_dir + '/train', sess.graph)test_writer = tf.train.SummaryWriter(summaries_dir + '/test')

然后在你的训练循环中,你正常进行训练并使用train_writer记录你的摘要。此外,每100次迭代时,你在测试集上运行图表,并仅使用test_writer记录准确度摘要。

# 记录训练集摘要,并进行训练summary, _ = sess.run([summary_op, train_step], feed_dict=...)train_writer.add_summary(summary, n)if n % 100 == 0:  # 记录摘要和测试集准确度  summary, acc = sess.run([accuracy_summary, accuracy], feed_dict=...)  test_writer.add_summary(summary, n)  print('Accuracy at step %s: %s' % (n, acc))

然后你可以将TensorBoard指向父目录(summaries_dir),它会加载这两个数据集。

这也可以在TensorFlow的HowTo中找到 https://www.tensorflow.org/versions/r0.11/how_tos/summaries_and_tensorboard/index.html

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注