如何使用 tf.reset_default_graph()

每当我尝试使用 tf.reset_default_graph() 时,我都会遇到这个错误:IndexError: list index out of range 或者 “。我应该在代码的哪个部分使用这个函数?我应该在什么时候使用它?

编辑:

我更新了代码,但错误仍然存在。

def evaluate():    with tf.name_scope("loss"):        global x # x 是 tf.placeholder()        xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=neural_network(x))        loss = tf.reduce_mean(xentropy, name="loss")    with tf.name_scope("train"):        optimizer = tf.train.AdamOptimizer()        training_op = optimizer.minimize(loss)    with tf.name_scope("exec"):        with tf.Session() as sess:            for i in range(1, 2):                sess.run(tf.global_variables_initializer())                sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})                print "Training " + str(i)                saver = tf.train.Saver()                saver.save(sess, "saved_models/testing")                print "Model Saved."def predict():    with tf.name_scope("predict"):        tf.reset_default_graph()        with tf.Session() as sess:            saver = tf.train.import_meta_graph("saved_models/testing.meta")            saver.restore(sess, "saved_models/testing")            output_ = tf.get_default_graph().get_tensor_by_name('output_layer:0')            print sess.run(output_, feed_dict={x: np.array([12003]).reshape([-1, 1])})def main():    print "Starting Program..."    evaluate()    writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())    predict()

如果我从更新后的代码中移除 tf.reset_default_graph(),我会得到这个错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

根据我目前的理解,tf.reset_default_graph() 会移除所有图形,因此我避免了上述提到的错误(ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used


回答:

这可能是你使用它的方法:

import tensorflow as tfa = tf.constant(1)with tf.Session() as sess:    tf.reset_default_graph()

你会遇到错误是因为你在会话中使用它。根据 tf.reset_default_graph() 的文档:

在活动的 tf.Session 或 tf.InteractiveSession 中调用此函数将导致未定义的行为。在调用此函数后使用任何先前创建的 tf.Operation 或 tf.Tensor 对象将导致未定义的行为


tf.reset_default_graph() 在测试阶段可能很有帮助(至少对我有帮助),当我在 jupyter 笔记本中进行实验时。然而,我从未在生产环境中使用过它,并且不明白它在那里会有什么帮助。

这是一个可能出现在笔记本中的示例:

import tensorflow as tf# 创建一些图形with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    print sess.run(...)

现在我不再需要这些东西了,但如果我创建另一个图形并在 tensorboard 中可视化它,我会看到旧节点和新节点。为了解决这个问题,我可以重启内核并只运行下一个单元格。然而,我可以简单地这样做:

tf.reset_default_graph()# 创建一个新图形with tf.Session() as sess:    print sess.run(...)

在 OP 添加了他的代码后的编辑

with tf.name_scope("predict"):    tf.reset_default_graph()

这里大约发生了什么。你的代码失败了,因为 tf.name_scope 已经向图形中添加了一些东西。在这个“向图形中添加东西”的过程中,你告诉 TF 完全移除图形,但它不能,因为它正忙于添加东西。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注