I am trying to load a previously saved model and save it again after retraining. Loading works fine, but saving fails as follows:
    sess = tf.Session()
    sess.run(init)
    loader = tf.train.import_meta_graph(self.model_path + '.meta')
    loader.restore(sess, self.model_path)  # tf.train.latest_checkpoint('./'))
    print('Model restored')
    # retrain
    saver = tf.train.Saver()
    saver.save(sess, self.model_path)
The first time I saved the model I had no such problem:
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(init)
    # train
    saver.save(sess, self.model_path)
The error I get is:
      File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__
        self.build()
      File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1170, in build
        restore_sequentially=self._restore_sequentially)
      File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 673, in build
        saveables = self._ValidateAndSliceInputs(names_to_saveables)
      File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 557, in _ValidateAndSliceInputs
        names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)
      File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 535, in OpListToDict
        name)
    ValueError: At least two variables have the same name: Variable_15/Adam
Answer:
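You see this message because two variables in the graph have the same name. tf.train.import_meta_graph reads the graph from the file and adds all of its operations and tensors to the current default graph, next to the ones that are already there. Every variable you had already built in Python (including Adam's slot variables such as Variable_15/Adam) therefore ends up registered twice, and the new tf.train.Saver() refuses to build. I'm surprised import_meta_graph didn't raise such an exception in the first place.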
Here is a complete example that reproduces this behavior:
    import tensorflow as tf

    # tiny graph
    x = tf.placeholder(tf.float32, shape=[1, 2], name='input')
    output = tf.identity(tf.layers.dense(x, 1), name='output')
    cost = tf.reduce_sum(x * output)
    # creates u'beta1_power:0', u'beta2_power:0' for the first time
    train_op = tf.train.AdamOptimizer().minimize(cost)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(tf.global_variables())
        saver.save(sess, './adam/my_model')
        print([v.name for v in tf.global_variables()])

        # creates u'beta1_power:0', u'beta2_power:0' a second time
        meta_graph = tf.train.import_meta_graph('./adam/my_model.meta')
        meta_graph.restore(sess, './adam/my_model')
        print([v.name for v in tf.global_variables()])

        # raises ValueError: the graph now holds two copies of
        # u'beta1_power:0', u'beta2_power:0', ...
        saver = tf.train.Saver(tf.global_variables())
        saver.save(sess, './adam/my_model2')
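Possible solutions:
- clear the graph with tf.reset_default_graph() before calling tf.train.import_meta_graph
- call tf.train.import_meta_graph inside a fresh graph (with tf.Graph().as_default():) and use a new session bound to that graph; a new session on the same default graph alone does not help, because import_meta_graph always imports into the default graph
- do not import the meta graph at all: rebuild the graph in Python and load only the weights with tf.train.Saver().restore(sess, '/tmp/model/my_model')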
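The first and third fixes can be sketched as below. This is a minimal, hypothetical example: the model, the helper names (build_graph, first_save, retrain_via_meta_graph, retrain_via_rebuilt_graph) and the temporary checkpoint paths are made up for illustration. It assumes TF 1.14+ or 2.x and goes through tf.compat.v1; on older 1.x releases such as the 1.2.1 in the traceback, use tf.* directly and drop disable_eager_execution(). Registering the training op in a "train_op" collection is also an assumption of this sketch, so the op can be found again after import_meta_graph.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 1.14+ or 2.x; on older 1.x (e.g. 1.2.1) replace tf1.* with tf.*.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def build_graph():
    """Hypothetical tiny model: one variable trained with Adam."""
    w = tf1.get_variable("w", shape=[2, 1])
    cost = tf.reduce_sum(tf.square(w - 1.0))
    train_op = tf1.train.AdamOptimizer(0.1).minimize(cost)
    # Store the op in a collection so it survives export/import of the meta graph.
    tf1.add_to_collection("train_op", train_op)
    return train_op


def first_save(path):
    tf1.reset_default_graph()
    train_op = build_graph()
    with tf1.Session() as sess:
        sess.run(tf1.global_variables_initializer())
        sess.run(train_op)
        tf1.train.Saver().save(sess, path)  # also writes path + ".meta"


def retrain_via_meta_graph(path, new_path):
    # Fix 1: start from an empty default graph, so import_meta_graph
    # does not add a second copy of every existing variable.
    tf1.reset_default_graph()
    with tf1.Session() as sess:
        loader = tf1.train.import_meta_graph(path + ".meta")
        loader.restore(sess, path)
        sess.run(tf1.get_collection("train_op")[0])  # retrain one step
        tf1.train.Saver().save(sess, new_path)  # no duplicate-name ValueError
        return sorted(v.name for v in tf1.global_variables())


def retrain_via_rebuilt_graph(path, new_path):
    # Fix 3: rebuild the graph in Python and restore only the weights.
    tf1.reset_default_graph()
    train_op = build_graph()
    saver = tf1.train.Saver()
    with tf1.Session() as sess:
        saver.restore(sess, path)  # restore replaces global_variables_initializer()
        sess.run(train_op)
        saver.save(sess, new_path)
        return sorted(v.name for v in tf1.global_variables())


ckpt_dir = tempfile.mkdtemp()
model_path = os.path.join(ckpt_dir, "my_model")
first_save(model_path)
names_meta = retrain_via_meta_graph(model_path, os.path.join(ckpt_dir, "my_model2"))
names_rebuilt = retrain_via_rebuilt_graph(model_path, os.path.join(ckpt_dir, "my_model3"))
print(names_meta)
print(names_rebuilt)
```

Both paths end with each variable (w, its two Adam slots, beta1_power, beta2_power) present exactly once, which is what the Saver needs. The second fix is the same as retrain_via_meta_graph with tf1.reset_default_graph() replaced by wrapping the body in with tf1.Graph().as_default():.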