### tf.Session()中的graph参数有什么作用?

我在理解tf.Session()中的graph参数时遇到了麻烦。我尝试在TensorFlow的网站上查找:链接,但没能理解太多内容。

我想弄清楚tf.Session()tf.Session(graph=some_graph_inserted_here)之间的区别。

问题背景

代码A(不工作):

def predict():    with tf.name_scope("predict"):        with tf.Session() as sess:            saver = tf.train.import_meta_graph("saved_models/testing.meta")            saver.restore(sess, "saved_models/testing")            loaded_graph = tf.get_default_graph()            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')            _x = loaded_graph.get_tensor_by_name('x:0')            print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})

这段代码在尝试加载图时给出了以下错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used,错误发生在saver = tf.train.import_meta_graph("saved_models/testing.meta")

代码B(工作):

def predict():    with tf.name_scope("predict"):        loaded_graph = tf.Graph()        with tf.Session(graph=loaded_graph) as sess:            saver = tf.train.import_meta_graph("saved_models/testing.meta")            saver.restore(sess, "saved_models/testing")            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')            _x = loaded_graph.get_tensor_by_name('x:0')            print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})

如果我用loaded_graph = tf.get_default_graph()替换loaded_graph = tf.Graph(),代码将无法工作。为什么?

如果有帮助,可以查看完整代码:(https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab)


回答:

TensorFlow的Graph对象包含了各种tf.Tensortf.Operation

当你创建这些张量(例如使用tf.Variabletf.constant)或操作(例如tf.matmul)时,它们会被添加到默认图中(查看这些对象的graph成员属性可以得到它们所属的图)。如果你没有指定任何东西,它将是调用tf.get_default_graph方法时获取的图。

但你也可以使用上下文管理器处理多个图:

g = tf.Graph()with g.as_default():    [你的代码]

假设你在代码中创建了多个图,你需要将你想要运行的图作为tf.Session方法的参数传递,以指定TensorFlow运行哪个图。

在代码A中,你

  • 使用默认图,
  • 尝试将元图导入其中(因为它已经包含了一些节点,所以失败了),以及,
  • 将模型恢复到其中,

而在代码B中,你

  • 创建一个全新的图,
  • 将元图导入其中(因为它是一个空图,所以成功了),以及
  • 恢复它。

有用链接:

tf.Graph API

编辑:

这段代码可以使代码A工作(我将默认图重置为一个新的图,并移除了predict的name_scope)。

def predict():    tf.reset_default_graph()    with tf.Session() as sess:        saver = tf.train.import_meta_graph("saved_models/testing.meta")        saver.restore(sess, "saved_models/testing")        loaded_graph = tf.get_default_graph()        output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')        _x = loaded_graph.get_tensor_by_name('x:0')        print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注