Tensorflow: 保存验证误差最小的模型

我在Tensorflow上运行了一个训练任务,并得到了验证集损失的如下曲线。网络在第6000次迭代后开始过拟合。因此,我想得到过拟合之前的模型。

loss

我的训练代码大致如下:

train_step = ......summary = tf.scalar_summary(l1_loss.op.name, l1_loss)summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph)saver = tf.train.Saver()for i in xrange(20000):    batch = get_next_batch(batch_size)    sess.run(train_step, feed_dict = {x: batch.x, y:batch.y})    if (i+1) % 100 == 0:        saver.save(sess, "checkpoint/net", global_step = i+1)        summary_str = sess.run(summary, feed_dict=validation_feed_dict)        summary_writer.add_summary(summary_str, i+1)        summary_writer.flush()

训练完成后,仅保存了五个检查点(19600, 19700, 19800, 19900, 20000)。有什么方法可以让Tensorflow根据验证误差来保存检查点吗?

附注:我知道tf.train.Saver有一个max_to_keep参数,原则上可以保存所有检查点。但这并不是我想要的(除非这是唯一的选择)。我想让保存器保留到目前为止验证损失最小的检查点。这可能吗?


回答:

你需要计算验证集上的分类准确率,并跟踪到目前为止见到的最佳准确率,只有当验证准确率有所改善时才保存检查点。

如果数据集和/或模型很大,你可能需要将验证集分成批次,以适应内存计算。

这个教程展示了你想要做的具体方法:

https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb

它也可作为一个短视频提供:

https://www.youtube.com/watch?v=Lx8JUJROkh0

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注