我有一个使用Tensorflow的简单神经网络。以下是会话代码:
with tensorFlow.Session() as sess: sess.run(tensorFlow.global_variables_initializer()) for epoch in range(epochs): i = 0 epochLoss = 0 for _ in range(int(len(data) / batchSize)): ex, ey = nextBatch(i) i += 1 feedDict = {x :ex, y:ey } _, cos = sess.run([optimizer,cost], feed_dict= feedDict) epochLoss += cos / (int(len(data)) / batchSize) print("Epoch", epoch + 1, "completed out of", epochs, "loss:", "{:.9f}".format(epochLoss)) save_path = saver.save(sess, "model.ckpt") print("Model saved in file: %s" % save_path)
在最后两行,我保存了模型,并在另一个类中恢复了图形:
with new_graph.as_default(): with tf.Session(graph=new_graph) as sess: sess.run(tf.global_variables_initializer()) new_saver = tf.train.import_meta_graph('model.ckpt.meta') new_saver.restore(sess, tf.train.latest_checkpoint('./'))
我想重新训练模型,这意味着不初始化权重,只是从上次停止的地方更新它们。
我该怎么做呢?
回答:
来自 https://www.tensorflow.org/api_docs/python/state_ops/saving_and_restoring_variables
tf.train.Saver.restore(sess, save_path)
恢复先前保存的变量。
此方法运行构造函数添加的用于恢复变量的操作。它需要一个已经启动图形的会话。要恢复的变量不必已经初始化,因为恢复本身就是一种初始化变量的方式。
以下示例来自 https://www.tensorflow.org/how_tos/variables/
# 创建一些变量。v1 = tf.Variable(..., name="v1")v2 = tf.Variable(..., name="v2")...# 添加保存和恢复所有变量的操作。saver = tf.train.Saver()# 稍后,启动模型,使用saver从磁盘恢复变量,并# 使用模型进行一些工作。with tf.Session() as sess: # 从磁盘恢复变量。 saver.restore(sess, "/tmp/model.ckpt") print("模型已恢复。") # 使用模型进行一些工作 ...