估计器:在准确性预测中应使用什么占位符?

我在TensorFlow中使用DNNClassifierLinearClassifier(为了比较使用两个不同的模型)对我的数据集进行二元分类。我已经成功地使这两个模型运行并输出了关于准确性的指标。然而,我希望能够得到一个包含模型对测试数据所有预测的数组。理想情况下,这些预测还应该有一些相关的概率值。

在StackOverflow和TensorFlow的GitHub页面上搜索后,我找到了这些代码行来提取预测:

sess = tf.Session()sess.run(tf.global_variables_initializer())sess.run(results, feed_dict={?})

唯一的问题是我不知道在feed_dict中应该放什么。大多数其他提问者似乎没有使用预制的估计器,因此我对在自己的代码中找到他们放入feed_dict的等价物感到困惑。我尝试使用我在test_input_fn方法中放入的内容,但得到了一个TypeError。我的代码中的其他关键部分如下所示:

def get_features(array):    return {'policy_state': np.array(array[:, 2], dtype=str),       'tiv': np.array(array[:, 4], dtype=int),       'veh_count': np.array(array[:, 5], dtype=int),       'generation': np.array(array[:, 8], dtype=str),       'modern_classic_ind': np.array(array[:, 17], dtype=str),       'h_plus_ind': np.array(array[:, 7], dtype=str)}train_input_fn = tf.estimator.inputs.numpy_input_fn(    x=get_features(my_data),    y=np.array(np.array(my_data[:, 11], dtype=int)),    num_epochs=None,    shuffle=True)test_input_fn = tf.estimator.inputs.numpy_input_fn(    x=get_features(test_data),    y=np.array(np.array(test_data[:, 11], dtype=int)),    num_epochs=None,    shuffle=True)

如果需要,我很乐意提供额外的代码/信息


回答:

使用tf.estimator.Estimator.predict方法,该方法返回一个预测结果的生成器。示例:

predict_input_fn = tf.estimator.inputs.numpy_input_fn({"x": x_test},                                                      batch_size=batch_size,                                                      num_epochs=1,                                                      shuffle=False)predictions = estimator.predict(input_fn=predict_input_fn)for prediction in predictions:  print(prediction)

注意num_epochs=1shuffle=False以使结果稳定。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注