假设我们有两个TensorFlow计算图,G1
和G2
,它们的保存权重分别为W1
和W2
。假设我们通过构建G1
和G2
来构建一个新的图G
。我们如何为这个新图G
恢复W1
和W2
?
举个简单的例子:
import tensorflow as tfV1 = tf.Variable(tf.zeros([1]))saver_1 = tf.train.Saver()V2 = tf.Variable(tf.zeros([1]))saver_2 = tf.train.Saver()sess = tf.Session()saver_1.restore(sess, 'W1')saver_2.restore(sess, 'W2')
在这个例子中,saver_1
成功恢复了对应的V1
,但saver_2
在尝试恢复时会失败并抛出NotFoundError
错误。
回答:
你可以使用两个保存器,每个保存器只查找一个变量。如果你只使用tf.train.Saver()
,我想它会查找你定义的所有变量。你可以通过使用tf.train.Saver([v1, ...])
来指定它查找的变量列表。有关更多信息,你可以阅读tf.train.Saver
构造函数的文档:https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver
这里有一个简单的可工作的例子。假设你在”save_vars.py”文件中进行计算,并且文件中有以下代码:
import tensorflow as tf# 图1 - 将v1的值设置为[1.0]g1 = tf.Graph()with g1.as_default(): v1 = tf.Variable(tf.zeros([1]), name="v1") assign1 = v1.assign(tf.constant([1.0])) init1 = tf.initialize_all_variables() save1 = tf.train.Saver()# 图2 - 将v2的值设置为[2.0]g2 = tf.Graph()with g2.as_default(): v2 = tf.Variable(tf.zeros([1]), name="v2") assign2 = v2.assign(tf.constant([2.0])) init2 = tf.initialize_all_variables() save2 = tf.train.Saver()# 对图1进行计算并保存sess1 = tf.Session(graph=g1)sess1.run(init1)print sess1.run(assign1)save1.save(sess1, "tmp/v1.ckpt")# 对图2进行计算并保存sess2 = tf.Session(graph=g2)sess2.run(init2)print sess2.run(assign2)save2.save(sess2, "tmp/v2.ckpt")
如果你确保有一个tmp
目录并运行python save_vars.py
,你将得到保存的检查点文件。
现在,你可以使用名为”restore_vars.py”的文件来恢复,文件中包含以下代码:
import tensorflow as tf# 我们想要恢复的变量v1和v2v1 = tf.Variable(tf.zeros([1]), name="v1")v2 = tf.Variable(tf.zeros([1]), name="v2")# saver1只查找v1saver1 = tf.train.Saver([v1])# saver2只查找v2saver2 = tf.train.Saver([v2])with tf.Session() as sess: saver1.restore(sess, "tmp/v1.ckpt") saver2.restore(sess, "tmp/v2.ckpt") print sess.run(v1) print sess.run(v2)
当你运行python restore_vars.py
时,输出应该是
[1.][2.]
(至少在我的电脑上是这样的输出)。如果有任何不清楚的地方,请随时发表评论。