如何使用tf.estimator.Estimator从检查点进行预测?

我刚刚使用TensorFlow训练了一个用于识别太阳黑子的CNN模型。我的模型与这个非常相似。问题在于,我找不到任何关于如何使用训练阶段生成的检查点进行预测的清晰说明。

我尝试使用标准的恢复方法:

saver = tf.train.import_meta_graph('./model/model.ckpt.meta')saver.restore(sess,'./model/model.ckpt')

但随后我无法弄清楚如何运行它。
我还尝试使用tf.estimator.Estimator.predict(),如下所示:

# 创建Estimator(应该重新加载最后的检查点,但实际上并没有)sunspot_classifier = tf.estimator.Estimator(    model_fn=cnn_model_fn, model_dir="./model")# 设置预测的日志记录# 记录标签为"probabilities"的"Softmax"张量的值tensors_to_log = {"probabilities": "softmax_tensor"}logging_hook = tf.train.LoggingTensorHook(    tensors=tensors_to_log, every_n_iter=50)# 使用模型进行预测并打印结果pred_input_fn = tf.estimator.inputs.numpy_input_fn(    x={"x": pred_data},    shuffle=False)pred_results = sunspot_classifier.predict(input_fn=pred_input_fn)print(pred_results)

但结果只是输出了<generator object Estimator.predict at 0x10dda6bf8>。然而,如果我使用相同的代码但改用tf.estimator.Estimator.evaluate(),它就能完美运行(重新加载模型,执行评估并发送到TensorBoard)。

我知道有很多类似的问题,但我实在找不到对我有效的方法。


回答:

sunspot_classifier.predict(input_fn=pred_input_fn)返回的是生成器。因此,pred_results是一个生成器对象。要从中获取值,你需要通过next(pred_results)来迭代它

解决方案是print(next(pred_results))

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注