在Tensorflow Estimator打印损失时使用的数据集是什么?

在使用Tensorflow Estimator时,它会在Python控制台上每100步打印一次全局步数和损失值(此外,它还会打印学习率、交叉熵和我的评估指标MAE,这些值会被打印在不同的行上,我认为这是由于某些包装函数所致,这些函数并不属于原始的Estimator API,因为我使用的是Google开发者的ResNet实现)。打印内容如下所示:

    I0530 19:20:42.748463 10964 tf_logging.py:116] learning_rate = 3.552962e-05, cross_entropy = 2.2080934, MAE = 5.135024 (62.295 sec)       I0530 19:20:42.749458 10964 tf_logging.py:116] loss = 2.2080934, step = 76066 (62.295 sec)

我的问题是,计算的是什么损失(或者计算的是什么MAE)?
是在特定步骤的日志记录发生时仅针对一个例子的损失吗?
是在特定步骤的日志记录发生时针对一个批次的损失吗?
还是整个训练集的损失?

另外,如果我的假设有误,请纠正我。我在这个领域还是新手。
谢谢。


回答:

tf.Estimator会自动为损失和全局步数设置一个LoggingTensorHook。据推测,你运行的代码为其他值(学习率、交叉熵(这实际上就是损失)和MAE)设置了单独的钩子,这就是为什么它们会被打印在不同的行上。

至于用于生成这些值的数据:它是“当前”的数据批次,即在进行日志记录的步骤中使用的数据批次。因此,在你提出的三个选项中,第二个选项是正确的。

这可以通过源代码确认,因为钩子会在运行后进行日志记录,接收到最后一次session.run()调用的结果(每次只获取一个批次)到run_values中。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注