Tensorflow: 如何在一个模块中同时使用saver.save()和.restore()?

我有一个名为neural.py的模块

我在模块主体中初始化变量。

import tensorflow as tf tf_x = tf.placeholder(tf.float32, [None, length])tf_y = tf.placeholder(tf.float32, [None, num_classes])...

我在train()函数中训练后保存检查点:

def train():    ...    pred = tf.layers.dense(dropout, num_classes, tf.identity)     ...    cross_entropy = tf.losses.softmax_cross_entropy(tf_y, pred)    ...    with tf.Session() as sess:        init = tf.global_variables_initializer()        sess.run(init)        saver = tf.train.Saver(tf.trainable_variables())        for ep in range(epochs):        ... (训练步骤)...    saver.save(sess, "checkpoints/cnn") 

我想在同一模块的run()函数中,在训练后恢复并运行网络:

def run():    # 我尝试在这里添加tf.reset_default_graph()    # 我也尝试使用tf.Graph().as_default() as g:并在tf.Session()中添加(graph=g)    saver = tf.train.Saver()    with tf.Session() as sess:        saver.restore(sess, "checkpoints/cnn")        ... (运行网络等)

但这不起作用。它要么给我NotFoundError (see above for traceback): Key beta2_power not found in checkpoint错误,要么如果我在run()下添加tf.reset_default_graph(),如上所述,会得到ValueError: No variables to save错误。

然而,如果我在没有train()的新模块中放入完全相同的run()代码,并且在顶部添加tf.reset_default_graph(),它可以完美工作。我如何在同一模块中使其工作?

最终代码片段:

if __name__ == '__main__':    print("开始训练")    train()    print("训练完成。生成预测")    run()

回答:

这可能是打字错误,但saver.save(sess, "checkpoints/cnn")绝对应该在with tf.Session() as sess块内,否则你保存的是一个已关闭的会话。

NotFoundError (see above for traceback): Key beta2_power not found in checkpoint

我认为问题在于你的图的一部分是在train中定义的。beta1_powerbeta2_powerAdapOptimizer的内部变量,它们与predsoftmax_cross_entropy一起,如果没有调用train()(例如被注释掉?),就不在图中。所以一个解决方案是使整个图在trainrun中都可访问。

另一个解决方案是将它们分开,并在run中使用恢复的图,而不是默认的图。如下所示:

tf.reset_default_graph()saver = tf.train.import_meta_graph('checkpoints/cnn.meta')with tf.Session() as sess:  saver.restore(sess, "checkpoints/cnn")  print("模型已恢复。")  tf_x = sess.graph.get_tensor_by_name('tf_x:0')  ...

但你需要为所有变量命名(无论如何这是一个好主意),然后在图中找到这些张量。在这里不能使用之前定义的变量。这种方法确保run方法使用保存的模型版本,可以轻松地提取到单独的脚本中等。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注