在TensorFlow中使用MonitoredTrainingSession与Estimator的原因

我看到许多例子中使用MonitoredTrainingSessiontf.Estimator作为训练框架。然而,我不清楚为什么要选择其中一个而不是另一个。两者都可以通过SessionRunHooks进行配置。两者都与tf.data.Dataset迭代器集成,并且可以输入训练/验证数据集。我不确定使用其中一种设置的优势是什么。


回答:

简短的回答是,MonitoredTrainingSession允许用户访问图形和会话对象以及训练循环,而Estimator则隐藏了图形和会话的细节,通常情况下,使得运行训练变得更加容易,特别是使用train_and_evaluate时,如果你需要定期进行评估的话。

MonitoredTrainingSession与普通的tf.Session()不同之处在于,它处理变量初始化,设置文件写入器,并且还包含分布式训练的功能。

另一方面,Estimator API是一个高级构造,类似于Keras。它可能在示例中使用较少,因为它是后来引入的。它还允许通过DistibutedStrategy来分布式训练/评估,并且它有一些预设的估算器,可以快速原型设计。

在模型定义方面,它们相当平等,两者都允许使用keras.layers,或者从头开始定义完全自定义的模型。因此,如果出于任何原因,你需要访问图形构建或自定义训练循环,请使用MonitoredTrainingSession。如果你只是想定义模型,训练它,运行验证和预测,而不需要额外的复杂性和样板代码,请使用Estimator

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注