NotFoundError when restoring a TensorFlow model

I'm trying to save a TensorFlow model and restore it in another file. I train and save the model with the following code.

import input_data
import os
import tensorflow as tf

mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

x = tf.placeholder("float", shape=[None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", shape=[None, 10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
saver = tf.train.Saver()

# train data and get results for batches
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)

# train the data
for i in range(10):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

print(batch_xs)
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
print ("accuracy", sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
prediction=tf.argmax(y,1)
arr=prediction.eval(feed_dict={x: mnist.test.images}, session=sess)
#print ("predictions", )
#for i in range(len(arr)):
    #print(arr[i])
save_path = saver.save(sess, '/model.ckpt')
print ('Model saved in file: ', save_path)

Then I try to restore it with the following code.

import input_data
import os
import tensorflow as tf

#mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

x = tf.placeholder("float", shape=[None, 784])
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x,W) + b)
y_ = tf.placeholder("float", shape=[None, 10])
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
tf.train.NewCheckpointReader("./model.ckpt")
with tf.Session() as sess:
        sess.run(init_op)
        #print("sess.run")
        saver.restore(sess, "./model.ckpt")
        print ("Model restored.")
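The training script saves with saver.save(sess, '/model.ckpt'), while the restore script reads './model.ckpt'. A quick way to check whether those two prefixes point at the same place is to resolve them with os.path.abspath and look for the index file. This is a minimal sketch that assumes the V2 checkpoint format (the TensorFlow 1.x default, consistent with the RestoreV2 ops in the error), in which saving to a prefix writes <prefix>.index next to the data shard(s); the helper names resolve_prefix and index_file_exists are hypothetical.

```python
import os


def resolve_prefix(prefix):
    """Absolute path a checkpoint prefix refers to from the current directory.

    On Windows, "/model.ckpt" resolves to the root of the current drive,
    while "./model.ckpt" resolves inside the working directory.
    """
    return os.path.abspath(prefix)


def index_file_exists(prefix):
    """True if `prefix.index` exists (assumes the V2 checkpoint format)."""
    return os.path.isfile(prefix + ".index")


save_prefix = "/model.ckpt"      # prefix passed to saver.save in the training script
restore_prefix = "./model.ckpt"  # prefix passed to saver.restore in the restore script

print("save    ->", resolve_prefix(save_prefix), index_file_exists(save_prefix))
print("restore ->", resolve_prefix(restore_prefix), index_file_exists(restore_prefix))
print("same location:", resolve_prefix(save_prefix) == resolve_prefix(restore_prefix))
```

Unless the working directory happens to be the root itself, the last line prints False, meaning the restore script is not reading the file the training script just wrote.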

On the line saver.restore(sess, "./model.ckpt") I get a NotFoundError. The error message is:

---------------------------------------------------------------------------
NotFoundError                             Traceback (most recent call last)
C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1020     try:
-> 1021       return fn(*args)
   1022     except errors.OpError as e:

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata)
   1002                                  feed_dict, fetch_list, target_list,
-> 1003                                  status, run_metadata)
   1004 

C:\Anaconda3\envs\tensorflow\lib\contextlib.py in __exit__(self, type, value, traceback)
     65             try:
---> 66                 next(self.gen)
     67             except StopIteration:

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status()
    468           compat.as_text(pywrap_tensorflow.TF_Message(status)),
--> 469           pywrap_tensorflow.TF_GetCode(status))
    470   finally:

NotFoundError: Key y_3 not found in checkpoint
     [[Node: save_16/RestoreV2_47 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_16/Const_0, save_16/RestoreV2_47/tensor_names, save_16/RestoreV2_47/shape_and_slices)]]

During handling of the above exception, another exception occurred:

NotFoundError                             Traceback (most recent call last)
<ipython-input-42-17503962c118> in <module>()
     17         sess.run(init_op)
     18         #print("sess.run")
---> 19         saver.restore(sess, "./model.ckpt")
     20         print ("Model restored.")

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py in restore(self, sess, save_path)
   1386       return
   1387     sess.run(self.saver_def.restore_op_name,
-> 1388              {self.saver_def.filename_tensor_name: save_path})
   1389 
   1390   @staticmethod

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
    764     try:
    765       result = self._run(None, fetches, feed_dict, options_ptr,
--> 766                          run_metadata_ptr)
    767       if run_metadata:
    768         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
    962     if final_fetches or final_targets:
    963       results = self._do_run(handle, final_targets, final_fetches,
--> 964                              feed_dict_string, options, run_metadata)
    965     else:
    966       results = []

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1012     if handle is None:
   1013       return self._do_call(_run_fn, self._session, feed_dict, fetch_list,
-> 1014                            target_list, options, run_metadata)
   1015     else:
   1016       return self._do_call(_prun_fn, self._session, handle, feed_dict,

C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1032         except KeyError:
   1033           pass
-> 1034       raise type(e)(node_def, op, message)
   1035 
   1036   def _extend_graph(self):

NotFoundError: Key y_3 not found in checkpoint
     [[Node: save_16/RestoreV2_47 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_16/Const_0, save_16/RestoreV2_47/tensor_names, save_16/RestoreV2_47/shape_and_slices)]]
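To find out which names the Saver is asking for that the checkpoint does not have, one can diff the two name lists. The comparison below is plain Python so it runs without TensorFlow; in TensorFlow 1.x the inputs could come from tf.train.NewCheckpointReader(path).get_variable_to_shape_map() on the checkpoint side and [v.op.name for v in tf.global_variables()] on the graph side, shown only as comments. The function name compare_names and the example names are hypothetical, not taken from the post's checkpoint.

```python
def compare_names(graph_names, checkpoint_keys):
    """Return (missing_from_checkpoint, unused_in_checkpoint), both sorted.

    missing_from_checkpoint: graph variables a default Saver would try to
        restore but that the checkpoint does not contain; each of these
        would trigger "Key ... not found in checkpoint".
    unused_in_checkpoint: checkpoint entries no graph variable maps to.
    """
    graph_set = set(graph_names)
    ckpt_set = set(checkpoint_keys)
    return sorted(graph_set - ckpt_set), sorted(ckpt_set - graph_set)


# With TensorFlow 1.x the inputs could be collected like this (not run here):
#   reader = tf.train.NewCheckpointReader("./model.ckpt")
#   checkpoint_keys = list(reader.get_variable_to_shape_map())
#   graph_names = [v.op.name for v in tf.global_variables()]

# Hypothetical names: a checkpoint written from a fresh graph, restored into
# a graph in which the model-building code has been run twice.
checkpoint_keys = ["Variable", "Variable_1"]
graph_names = ["Variable", "Variable_1", "Variable_2", "Variable_3"]

missing, unused = compare_names(graph_names, checkpoint_keys)
print("missing from checkpoint:", missing)  # ['Variable_2', 'Variable_3']
print("unused checkpoint keys:", unused)    # []
```

Any name in the first list is a variable the restore will fail on; passing an explicit var_list to tf.train.Saver, or rebuilding the graph from scratch, are the usual ways to make the lists agree.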

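The message says that the key y_3 is not in the checkpoint. This usually means the variables in the restoring graph do not match, by name or structure, the ones that were saved. Check that the variable definitions and names are the same in the save and restore code, and that the restore call uses the correct path and filename. Two details in the post point in that direction. First, the training script saves to '/model.ckpt' (the root of the current drive), while the restore script reads './model.ckpt' (the current working directory), so the two may not be the same file. Second, the op name save_16 and the <ipython-input-42-17503962c118> frame suggest the cells were run repeatedly in one Jupyter kernel; each run adds another copy of the variables and another Saver to the default graph under suffixed names (_1, _2, ...), which a checkpoint saved from a fresh graph does not contain. In TensorFlow 1.x, calling tf.reset_default_graph() at the top of the cell, before the model is rebuilt, clears those leftovers.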
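The suffixed names in the error (y_3, save_16) follow the pattern TensorFlow 1.x uses when the same name is requested more than once in a graph: the first request gets the bare name, later ones get _1, _2, and so on. The class below is a simplified pure-Python model of that behavior (an illustration under that assumption, not TensorFlow's implementation; NameRegistry and run_model_cell are hypothetical names), showing why running the model-building cell twice in one kernel produces variable names that a checkpoint saved from a fresh graph never contained.

```python
class NameRegistry:
    """Simplified model of name uniquification in a TF1 default graph.

    Illustrative assumption: the first request for a base name returns it
    unchanged, and later requests return base_1, base_2, ...
    """

    def __init__(self):
        self._counts = {}

    def unique_name(self, base):
        count = self._counts.get(base, 0)
        self._counts[base] = count + 1
        return base if count == 0 else f"{base}_{count}"


def run_model_cell(graph):
    """Mimic one execution of a cell that creates W, b and a Saver.

    In TF1, tf.Variable(...) without a name argument defaults to "Variable",
    and tf.train.Saver() builds its ops under the "save" name scope.
    """
    return [graph.unique_name("Variable"),  # W
            graph.unique_name("Variable"),  # b
            graph.unique_name("save")]      # Saver scope


graph = NameRegistry()          # stands in for one long-lived Jupyter kernel
first = run_model_cell(graph)
second = run_model_cell(graph)
print(first)   # ['Variable', 'Variable_1', 'save']
print(second)  # ['Variable_2', 'Variable_3', 'save_1']
```

A fresh NameRegistry corresponds to a freshly reset graph: it hands out the bare names again, which is why rebuilding the graph from scratch before restoring brings the names back in line with the checkpoint.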

