TensorFlow: 恢复多个图

假设我们有两个TensorFlow计算图,G1G2,它们的保存权重分别为W1W2。假设我们通过构建G1G2来构建一个新的图G。我们如何为这个新图G恢复W1W2

举个简单的例子:

import tensorflow as tfV1 = tf.Variable(tf.zeros([1]))saver_1 = tf.train.Saver()V2 = tf.Variable(tf.zeros([1]))saver_2 = tf.train.Saver()sess = tf.Session()saver_1.restore(sess, 'W1')saver_2.restore(sess, 'W2')

在这个例子中,saver_1成功恢复了对应的V1,但saver_2在尝试恢复时会失败并抛出NotFoundError错误。


回答:

你可以使用两个保存器,每个保存器只查找一个变量。如果你只使用tf.train.Saver(),我想它会查找你定义的所有变量。你可以通过使用tf.train.Saver([v1, ...])来指定它查找的变量列表。有关更多信息,你可以阅读tf.train.Saver构造函数的文档:https://www.tensorflow.org/versions/r0.11/api_docs/python/state_ops.html#Saver

这里有一个简单的可工作的例子。假设你在”save_vars.py”文件中进行计算,并且文件中有以下代码:

import tensorflow as tf# 图1 - 将v1的值设置为[1.0]g1 = tf.Graph()with g1.as_default():    v1 = tf.Variable(tf.zeros([1]), name="v1")    assign1 = v1.assign(tf.constant([1.0]))    init1 = tf.initialize_all_variables()    save1 = tf.train.Saver()# 图2 - 将v2的值设置为[2.0]g2 = tf.Graph()with g2.as_default():    v2 = tf.Variable(tf.zeros([1]), name="v2")    assign2 = v2.assign(tf.constant([2.0]))    init2 = tf.initialize_all_variables()    save2 = tf.train.Saver()# 对图1进行计算并保存sess1 = tf.Session(graph=g1)sess1.run(init1)print sess1.run(assign1)save1.save(sess1, "tmp/v1.ckpt")# 对图2进行计算并保存sess2 = tf.Session(graph=g2)sess2.run(init2)print sess2.run(assign2)save2.save(sess2, "tmp/v2.ckpt")

如果你确保有一个tmp目录并运行python save_vars.py,你将得到保存的检查点文件。

现在,你可以使用名为”restore_vars.py”的文件来恢复,文件中包含以下代码:

import tensorflow as tf# 我们想要恢复的变量v1和v2v1 = tf.Variable(tf.zeros([1]), name="v1")v2 = tf.Variable(tf.zeros([1]), name="v2")# saver1只查找v1saver1 = tf.train.Saver([v1])# saver2只查找v2saver2 = tf.train.Saver([v2])with tf.Session() as sess:    saver1.restore(sess, "tmp/v1.ckpt")    saver2.restore(sess, "tmp/v2.ckpt")    print sess.run(v1)    print sess.run(v2)

当你运行python restore_vars.py时,输出应该是

[1.][2.]

(至少在我的电脑上是这样的输出)。如果有任何不清楚的地方,请随时发表评论。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注