当我们在 Saver.save
中指定 global_step 时,它会将 global_step 存储为检查点的后缀。
# 保存检查点
saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)
我们可以这样恢复检查点并获取存储在检查点中的最后一个全局步骤:
# 恢复检查点并获取全局步骤
saver.restore(session, ckpt.model_checkpoint_path)
...
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
如果我们使用 tf.train.MonitoredTrainingSession
,如何等效地将全局步骤保存到检查点并获取 gstep
?
编辑 1
按照Maxim的建议,我在 tf.train.MonitoredTrainingSession
之前创建了 global_step
变量,并添加了一个 CheckpointSaverHook
,如下所示:
global_step = tf.train.get_or_create_global_step()
save_checkpoint_hook = tf.train.CheckpointSaverHook(checkpoint_dir=checkpoints_abs_path,
save_steps=5,
checkpoint_basename=(checkpoints_prefix + ".ckpt"))
with tf.train.MonitoredTrainingSession(master=server.target,
is_chief=is_chief,
hooks=[sync_replicas_hook, save_checkpoint_hook],
config=config) as session:
_, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
print("当前全局步骤=" + str(gstep))
我发现它生成了与 Saver.saver
类似的检查点文件。然而,它无法从检查点中检索全局步骤。请指导我如何修复这个问题?
回答:
您可以通过 tf.train.get_global_step()
或 tf.train.get_or_create_global_step()
函数获取当前的全局步骤。后者应在训练开始前调用。
对于监控会话,将 tf.train.CheckpointSaverHook
添加到 hooks
中,它会在每 N 步后使用定义的全局步骤张量内部保存模型。