保存重新训练的TensorFlow模型时遇到问题

我尝试加载一个之前保存的模型,并在重新训练后保存它。加载过程进行得很顺利,但保存时遇到了如下问题:

sess=tf.Session()sess.run(init)loader = tf.train.import_meta_graph(self.model_path+'.meta')loader.restore(sess,self.model_path)#tf.train.latest_checkpoint('./'))            print('Model restored')#retrainsaver=tf.train.Saver()saver.save(sess, self.model_path)

第一次保存时我没有遇到类似的任何问题,如下所示:

saver=tf.train.Saver()sess=tf.Session()sess.run(init)#trainsaver.save(sess, self.model_path)

我遇到的错误是:

File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1139, in __init__    self.build()  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 1170, in build    restore_sequentially=self._restore_sequentially)  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 673, in build    saveables = self._ValidateAndSliceInputs(names_to_saveables)  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 557, in _ValidateAndSliceInputs    names_to_saveables = BaseSaverBuilder.OpListToDict(names_to_saveables)  File "/share/apps/python2.7-tensorflow-1.2.1/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 535, in OpListToDict    name)ValueError: At least two variables have the same name: Variable_15/Adam

回答:

您看到这条消息是因为在作用域中有两个变量名称相同。tf.train.import_meta_graph从文件中读取图并将所有操作和张量添加到当前现有图中。我很惊讶import_meta_graph一开始甚至没有触发这样的异常。

请查看完整的示例以重现此行为:

import tensorflow as tf# tiny graphx = tf.placeholder(tf.float32, shape=[1, 2], name='input')output = tf.identity(tf.layers.dense(x, 1), name='output')cost = tf.reduce_sum(x * output)# create first time u'beta1_power:0', u'beta2_power:0'train_op = tf.train.AdamOptimizer().minimize(cost)with tf.Session() as sess:    sess.run(tf.global_variables_initializer())    saver = tf.train.Saver(tf.global_variables())    saver.save(sess, './adam/my_model')    print([v.name for v in tf.global_variables()])    # create second time u'beta1_power:0', u'beta2_power:0'    meta_graph = tf.train.import_meta_graph('./adam/my_model.meta')    meta_graph.restore(sess, './adam/my_model')    print([v.name for v in tf.global_variables()])    saver = tf.train.Saver(tf.global_variables())    # exception as there are now two times: u'beta1_power:0', u'beta2_power:0'    saver.save(sess, './adam/my_model2')

解决方案包括:

  • tf.train.import_meta_graph之前使用tf.reset_default_graph()清除图
  • tf.train.import_meta_graph使用一个新会话
  • 仅使用tf.train.Saver().restore(sess, '/tmp/model/my_model')加载权重

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注