我有一个名为neural.py
的模块
我在模块主体中初始化变量。
import tensorflow as tf tf_x = tf.placeholder(tf.float32, [None, length])tf_y = tf.placeholder(tf.float32, [None, num_classes])...
我在train()
函数中训练后保存检查点:
def train(): ... pred = tf.layers.dense(dropout, num_classes, tf.identity) ... cross_entropy = tf.losses.softmax_cross_entropy(tf_y, pred) ... with tf.Session() as sess: init = tf.global_variables_initializer() sess.run(init) saver = tf.train.Saver(tf.trainable_variables()) for ep in range(epochs): ... (训练步骤)... saver.save(sess, "checkpoints/cnn")
我想在同一模块的run()
函数中,在训练后恢复并运行网络:
def run(): # 我尝试在这里添加tf.reset_default_graph() # 我也尝试使用tf.Graph().as_default() as g:并在tf.Session()中添加(graph=g) saver = tf.train.Saver() with tf.Session() as sess: saver.restore(sess, "checkpoints/cnn") ... (运行网络等)
但这不起作用。它要么给我NotFoundError (see above for traceback): Key beta2_power not found in checkpoint
错误,要么如果我在run()
下添加tf.reset_default_graph()
,如上所述,会得到ValueError: No variables to save
错误。
然而,如果我在没有train()
的新模块中放入完全相同的run()
代码,并且在顶部添加tf.reset_default_graph()
,它可以完美工作。我如何在同一模块中使其工作?
最终代码片段:
if __name__ == '__main__': print("开始训练") train() print("训练完成。生成预测") run()
回答:
这可能是打字错误,但saver.save(sess, "checkpoints/cnn")
绝对应该在with tf.Session() as sess
块内,否则你保存的是一个已关闭的会话。
NotFoundError (see above for traceback): Key beta2_power not found in checkpoint
我认为问题在于你的图的一部分是在train
中定义的。beta1_power
和beta2_power
是AdapOptimizer
的内部变量,它们与pred
和softmax_cross_entropy
一起,如果没有调用train()
(例如被注释掉?),就不在图中。所以一个解决方案是使整个图在train
和run
中都可访问。
另一个解决方案是将它们分开,并在run
中使用恢复的图,而不是默认的图。如下所示:
tf.reset_default_graph()saver = tf.train.import_meta_graph('checkpoints/cnn.meta')with tf.Session() as sess: saver.restore(sess, "checkpoints/cnn") print("模型已恢复。") tf_x = sess.graph.get_tensor_by_name('tf_x:0') ...
但你需要为所有变量命名(无论如何这是一个好主意),然后在图中找到这些张量。在这里不能使用之前定义的变量。这种方法确保run
方法使用保存的模型版本,可以轻松地提取到单独的脚本中等。