我在理解tf.Session()
中的graph参数时遇到了麻烦。我尝试在TensorFlow的网站上查找:链接,但没能理解太多内容。
我想弄清楚tf.Session()
和tf.Session(graph=some_graph_inserted_here)
之间的区别。
问题背景
代码A(不工作):
def predict(): with tf.name_scope("predict"): with tf.Session() as sess: saver = tf.train.import_meta_graph("saved_models/testing.meta") saver.restore(sess, "saved_models/testing") loaded_graph = tf.get_default_graph() output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0') _x = loaded_graph.get_tensor_by_name('x:0') print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})
这段代码在尝试加载图时给出了以下错误:ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used
,错误发生在saver = tf.train.import_meta_graph("saved_models/testing.meta")
处
代码B(工作):
def predict(): with tf.name_scope("predict"): loaded_graph = tf.Graph() with tf.Session(graph=loaded_graph) as sess: saver = tf.train.import_meta_graph("saved_models/testing.meta") saver.restore(sess, "saved_models/testing") output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0') _x = loaded_graph.get_tensor_by_name('x:0') print sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])})
如果我用loaded_graph = tf.get_default_graph()
替换loaded_graph = tf.Graph()
,代码将无法工作。为什么?
如果有帮助,可以查看完整代码:(https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab)
回答:
TensorFlow的Graph
对象包含了各种tf.Tensor
和tf.Operation
。
当你创建这些张量(例如使用tf.Variable
或tf.constant
)或操作(例如tf.matmul
)时,它们会被添加到默认图中(查看这些对象的graph
成员属性可以得到它们所属的图)。如果你没有指定任何东西,它将是调用tf.get_default_graph
方法时获取的图。
但你也可以使用上下文管理器处理多个图:
g = tf.Graph()with g.as_default(): [你的代码]
假设你在代码中创建了多个图,你需要将你想要运行的图作为tf.Session
方法的参数传递,以指定TensorFlow运行哪个图。
在代码A中,你
- 使用默认图,
- 尝试将元图导入其中(因为它已经包含了一些节点,所以失败了),以及,
- 将模型恢复到其中,
而在代码B中,你
- 创建一个全新的图,
- 将元图导入其中(因为它是一个空图,所以成功了),以及
- 恢复它。
有用链接:
编辑:
这段代码可以使代码A工作(我将默认图重置为一个新的图,并移除了predict的name_scope
)。
def predict(): tf.reset_default_graph() with tf.Session() as sess: saver = tf.train.import_meta_graph("saved_models/testing.meta") saver.restore(sess, "saved_models/testing") loaded_graph = tf.get_default_graph() output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0') _x = loaded_graph.get_tensor_by_name('x:0') print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))