我想在TensorFlow中只恢复计算图的一部分。我的架构包含两个网络。第一个网络的输出是第二个网络的输入。第一个网络已经预训练,我希望从一个检查点中恢复。我不希望更新第一个网络的参数。有没有我可以遵循的例子来实现这一点?
谢谢
回答:
我没有您任务的具体代码,但这里有一个简短的指南,可能会对您有所帮助:
首先,您需要将您的网络解析为tf.GraphDef
格式,代码应该如下所示:
graph_def = tf.GraphDef()with tf.gfile.FastGFile("path/to/graphdef") as f: s = f.read()graph_def.ParseFromString(s)
或者从检查点/保存的模式中恢复,然后转换为GraphDef
,方法如下:
tf.train.import_meta_graph('checkpoint.meta')tf.get_default_graph().as_graph_def()
现在您已经有了graph_def
其次,使用tf.graph_util.extract_sub_graph
从graph_def
中提取子图,您可以指定作为第二个网络输入的目标节点。
最后,使用tf.import_graph_def
导入第二步中的子图。
另外,由于您不希望更新第一个网络的参数,您可以使用tf.graph_util.convert_variables_to_constants
来冻结其参数