我试图保存模型的迭代检查点,同时也保存在一个独立验证数据集上取得最佳分数的模型。然而,我的检查点会覆盖我的最佳模型。实际上,我使用的是类似于以下的代码:
saver = tf.train.Saver()with tf.Session() as sess: for epoch in range(20): # 训练模型 [...] # 并保存检查点 saver.save(sess, "iter", global_step=epoch) if best_validiation_acc < last_validation_acc: saver.save(sess, "best_model")
如何防止我的迭代保存覆盖最佳模型?
回答:
原因是你对两者使用了同一个tf.train.Saver
,因此它会记住最后的max_to_keep=5
个检查点文件,无论你如何命名它们。
最简单的解决方案是设置max_to_keep=None
,这将强制保存器保留所有检查点,并且不会覆盖任何内容。然而,你可能更希望至少覆盖迭代检查点。在这种情况下,解决方案是:
iter_saver = tf.train.Saver(max_to_keep=3) # 保留最后3次迭代best_saver = tf.train.Saver(max_to_keep=5) # 保留最后5个最佳模型with tf.Session() as sess: for epoch in range(20): # 训练模型 [...] # 并保存检查点 iter_saver.save(sess, "iter/model", global_step=epoch) if best_validiation_acc < last_validation_acc: best_saver.save(sess, "best/model")
我还建议使用不同的目录,这样checkpoint
文件就不会发生冲突。