我尝试使用TensorFlow保存模型并在另一个文件中恢复。我使用以下代码进行训练和保存模型。
import input_dataimport osimport tensorflow as tfmnist = input_data.read_data_sets('MNIST_data', one_hot=True)x = tf.placeholder("float", shape=[None, 784])W = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,W) + b)y_ = tf.placeholder("float", shape=[None, 10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)saver = tf.train.Saver()# train data and get results for batchesinit = tf.global_variables_initializer()sess = tf.Session()sess.run(init)# train the datafor i in range(10): batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})print(batch_xs)correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))print ("accuracy", sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))prediction=tf.argmax(y,1)arr=prediction.eval(feed_dict={x: mnist.test.images}, session=sess)#print ("predictions", )#for i in range(len(arr)): #print(arr[i])save_path = saver.save(sess, '/model.ckpt')print ('Model saved in file: ', save_path)
然后尝试使用以下代码进行恢复。
import input_dataimport osimport tensorflow as tf#mnist = input_data.read_data_sets('MNIST_data', one_hot=True)x = tf.placeholder("float", shape=[None, 784])W = tf.Variable(tf.zeros([784,10]))b = tf.Variable(tf.zeros([10]))y = tf.nn.softmax(tf.matmul(x,W) + b)y_ = tf.placeholder("float", shape=[None, 10])cross_entropy = -tf.reduce_sum(y_*tf.log(y))train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)init_op = tf.global_variables_initializer()saver = tf.train.Saver()tf.train.NewCheckpointReader("./model.ckpt")with tf.Session() as sess: sess.run(init_op) #print("sess.run") saver.restore(sess, "./model.ckpt") print ("Model restored.")
在saver.restore(sess, "./model.ckpt")
这一行,我遇到了NotFoundError。错误信息如下:
---------------------------------------------------------------------------NotFoundError Traceback (most recent call last)C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args) 1020 try:-> 1021 return fn(*args) 1022 except errors.OpError as e:C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run_fn(session, feed_dict, fetch_list, target_list, options, run_metadata) 1002 feed_dict, fetch_list, target_list,-> 1003 status, run_metadata) 1004 C:\Anaconda3\envs\tensorflow\lib\contextlib.py in __exit__(self, type, value, traceback) 65 try:---> 66 next(self.gen) 67 except StopIteration:C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\framework\errors_impl.py in raise_exception_on_not_ok_status() 468 compat.as_text(pywrap_tensorflow.TF_Message(status)),--> 469 pywrap_tensorflow.TF_GetCode(status)) 470 finally:NotFoundError: Key y_3 not found in checkpoint [[Node: save_16/RestoreV2_47 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_16/Const_0, save_16/RestoreV2_47/tensor_names, save_16/RestoreV2_47/shape_and_slices)]]During handling of the above exception, another exception occurred:NotFoundError Traceback (most recent call last)<ipython-input-42-17503962c118> in <module>() 17 sess.run(init_op) 18 #print("sess.run")---> 19 saver.restore(sess, "./model.ckpt") 20 print ("Model restored.")C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\training\saver.py in restore(self, sess, save_path) 1386 return 1387 sess.run(self.saver_def.restore_op_name,-> 1388 {self.saver_def.filename_tensor_name: save_path}) 1389 1390 @staticmethodC:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata) 764 try: 765 result = self._run(None, fetches, feed_dict, options_ptr,--> 766 run_metadata_ptr) 767 if run_metadata: 768 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata) 962 if final_fetches or final_targets: 963 results = self._do_run(handle, final_targets, final_fetches,--> 964 feed_dict_string, options, run_metadata) 965 else: 966 results = []C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata) 1012 if handle is None: 1013 return self._do_call(_run_fn, self._session, feed_dict, fetch_list,-> 1014 target_list, options, run_metadata) 1015 else: 1016 return self._do_call(_prun_fn, self._session, handle, feed_dict,C:\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args) 1032 except KeyError: 1033 pass-> 1034 raise type(e)(node_def, op, message) 1035 1036 def _extend_graph(self):NotFoundError: Key y_3 not found in checkpoint [[Node: save_16/RestoreV2_47 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save_16/Const_0, save_16/RestoreV2_47/tensor_names, save_16/RestoreV2_47/shape_and_slices)]]
在saver.restore(sess, "./model.ckpt")
这一行,我遇到了NotFoundError。错误信息显示在检查点中找不到键y_3。这可能是因为在保存和恢复模型时,变量的名称或结构不一致导致的。请检查保存和恢复代码中的变量定义和名称是否一致,并确保在恢复模型时使用了正确的路径和文件名。
回答: