我对机器学习和Tensorflow框架还比较新手。我尝试使用这里展示的代码作为参考,训练了一个模型,使用MNIST手写数字数据集,并希望对自己创建的测试样本进行推理。然而,我是在一台带有GPU的远程机器上进行训练的,我试图将数据保存到一个目录中,以便将数据传输到本地机器上进行推理。
我似乎能够使用tf.saved_model.simple_save
保存部分模型,但是我不确定如何使用保存的数据进行推理,以及如何使用这些数据对新图像进行预测。看起来保存模型的方式有很多种,但我不知道在Tensorflow框架中使用哪种方法是常规的或“正确”的方式。
到目前为止,我认为我需要的代码行如下,但不确定是否正确。
tf.saved_model.simple_save(sess, 'mnist_model', inputs={'x': self.x}, outputs={'y_': self.y_, 'y_conv':self.y_conv})
如果有人能指导我如何正确保存训练模型以及使用哪些变量来使用保存的模型进行推理,我将不胜感激。
回答:
你可以这样做:在你的图定义中创建一个tf.train.Saver()
对象,然后使用它将网络保存到指定的目录中。然后可以从远程机器下载该目录中的权重到本地机器并在本地恢复。这里是一个小型示例网络:
import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets('MNIST_data', one_hot=True)# >>>> 配置变量 <<<<TRAIN_STEPS = 1000SAVE_EVERY = 100# >>>> 网络 <<<<inputs = tf.placeholder(tf.float32, shape=[None, 784])labels = tf.placeholder(tf.float32, shape=[None, 10])h1 = tf.layers.dense(inputs, 256, activation=tf.nn.relu, use_bias=True)logits = tf.layers.dense(h1, 10, use_bias=True)predictions = tf.nn.softmax(logits)prediction_ids = tf.argmax(predictions, axis=1)# >>>> 损失与优化 <<<<loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=labels, logits=logits)opt = tf.train.AdamOptimizer().minimize(loss)# >>>> 工具 <<<<init = tf.global_variables_initializer()saver = tf.train.Saver()with tf.Session() as sess: sess.run(init) # >>>> 训练 - 在远程机器上运行,注释掉本地部分 <<<< for i in range(TRAIN_STEPS): print("训练步骤 {}".format(i), end="\r") batch_data, batch_labels = mnist.train.next_batch(batch_size=128) feed_dict = { inputs: batch_data, labels: batch_labels } l, _ = sess.run([loss, opt], feed_dict=feed_dict) if i % SAVE_EVERY == 0: saver.save(sess, "saved_model/network_weights.ckpt") # >>>> 使用网络 - 在本地运行以使用网络 <<< saver.restore(sess, "saved_model/network_weights.ckpt") test_data, test_labels = mnist.test.images, mnist.test.labels feed_dict = { inputs: test_data, labels: test_labels } preds = sess.run(prediction_ids, feed_dict=feed_dict) print(preds)
所以一旦你在网络中定义了saver,你就可以使用它将权重保存到指定的目录中 – 在这个例子中是“saved_models”目录,你需要在运行这段代码之前创建这个目录。
恢复模型只需调用saver.restore()
并传递会话和权重存储的路径即可。因此,你可以在远程机器上运行这段代码,下载“saved_models”目录到本地机器,然后运行这段代码并注释掉训练部分来实际使用模型。