我刚刚使用TensorFlow训练了一个用于识别太阳黑子的CNN模型。我的模型与这个非常相似。问题在于,我找不到任何关于如何使用训练阶段生成的检查点进行预测的清晰说明。
我尝试使用标准的恢复方法:
saver = tf.train.import_meta_graph('./model/model.ckpt.meta')saver.restore(sess,'./model/model.ckpt')
但随后我无法弄清楚如何运行它。
我还尝试使用tf.estimator.Estimator.predict()
,如下所示:
# 创建Estimator(应该重新加载最后的检查点,但实际上并没有)sunspot_classifier = tf.estimator.Estimator( model_fn=cnn_model_fn, model_dir="./model")# 设置预测的日志记录# 记录标签为"probabilities"的"Softmax"张量的值tensors_to_log = {"probabilities": "softmax_tensor"}logging_hook = tf.train.LoggingTensorHook( tensors=tensors_to_log, every_n_iter=50)# 使用模型进行预测并打印结果pred_input_fn = tf.estimator.inputs.numpy_input_fn( x={"x": pred_data}, shuffle=False)pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)print(pred_results)
但结果只是输出了<generator object Estimator.predict at 0x10dda6bf8>
。然而,如果我使用相同的代码但改用tf.estimator.Estimator.evaluate()
,它就能完美运行(重新加载模型,执行评估并发送到TensorBoard)。
我知道有很多类似的问题,但我实在找不到对我有效的方法。
回答:
sunspot_classifier.predict(input_fn=pred_input_fn)
返回的是生成器。因此,pred_results
是一个生成器对象。要从中获取值,你需要通过next(pred_results)
来迭代它
解决方案是print(next(pred_results))