我试图保存一些变量并查看是否可以稍后恢复它们。以下是我的保存代码:
import tensorflow as tf; my_a = tf.Variable(2,name = "my_a"); my_b = tf.Variable(3,name = "my_b"); my_c = tf.Variable(4,name = "my_c"); my_c = tf.add(my_a,my_b); with tf.Session() as sess: init = tf.initialize_all_variables(); sess.run(init); print("my_c = ",sess.run(my_c)); saver = tf.train.Saver(); saver.save(sess,"test.ckpt");
这会输出:
my_c = 5
当我恢复它时:
import tensorflow as tf; c = tf.Variable(3100,dtype = tf.int32); with tf.Session() as sess: sess.run(tf.initialize_all_variables()); saver = tf.train.Saver({"my_c":c}); saver.restore(sess, "test.ckpt"); cc= sess.run(c); print(cc);
这会给我:
4
恢复的 my_c 值应该是 5,因为它是 my_a 和 my_b 的和。然而,它给我的是 4,这是 my_c 的初始化值。能有人解释为什么会这样吗,以及如何保存变量的更改?
回答:
在你的原始代码中,你实际上并没有将名为 my_c
的变量(请注意,TensorFlow 的 name
)赋值为 my_a + my_b
。
通过编写 my_c = tf.add(my_a,my_b)
,Python 变量 my_c
现在与具有 name='my_c'
的 tf.Variable
不同。
当你执行 sess.run()
时,你只是在执行操作,并没有更新那个变量。
如果你想让这段代码正确运行,请使用以下代码 – (查看注释中的更改)
import tensorflow as tfmy_a = tf.Variable(2,name = "my_a")my_b = tf.Variable(3,name = "my_b")my_c = tf.Variable(4,name="my_c")# 使用 assign() 函数设置新值add = my_c.assign(tf.add(my_a,my_b))with tf.Session() as sess: init = tf.initialize_all_variables() sess.run(init) # 执行 add 操作符 sess.run(add) print("my_c = ",sess.run(my_c)) saver = tf.train.Saver() saver.save(sess,"test.ckpt")