我使用了高级的tf API,如tf.estimator,创建了一个网络。
训练和评估都能正常工作并产生输出。然而,在对新数据进行预测时,get_inputs()
需要label_data
和batch_size
。
错误信息是:TypeError: get_inputs() missing 2 required positional arguments: 'label_data' and 'batch_size'
如何解决这个问题以便进行预测?
这是我的代码:
predictTest = [0.34, 0.65, 0.88]
predictTest只是一个测试,不会是我真正的预测数据。
错误是在get_inputs()
中抛出的。
def get_inputs(feature_data, label_data, batch_size, n_epochs=None, shuffle=True): dataset = tf.data.Dataset.from_tensor_slices( (feature_data, label_data)) dataset = dataset.repeat(n_epochs) if shuffle: dataset = dataset.shuffle(len(feature_data)) dataset = dataset.batch(batch_size) features, labels = dataset.make_one_shot_iterator().get_next() return features, labels
预测输入:
def predict_input_fn(): return get_inputs( predictTest, n_epochs=1, shuffle=False )
预测:
predict = estimator.predict(predict_input_fn)print("Prediction: {}".format(list(predict)))
回答:
我发现必须为预测创建一个新的get_inputs()
函数。
如果我使用训练和评估使用的get_inputs()
,它会期待一些它不会得到的数据。
get_inputs
:
def get_inputs(feature_data, label_data, batch_size, n_epochs=None, shuffle=True): dataset = tf.data.Dataset.from_tensor_slices( #from_tensor_slices (feature_data, label_data)) dataset = dataset.repeat(n_epochs) if shuffle: dataset = dataset.shuffle(len(feature_data)) dataset = dataset.batch(batch_size) features, labels = dataset.make_one_shot_iterator().get_next() return features, labels
创建一个新的函数,名为pred_get_inputs,不需要label_data
或batch_size
:
def get_pred_inputs(feature_data,n_epochs=None, shuffle=False): dataset = tf.data.Dataset.from_tensor_slices( #from_tensor_slices (feature_data)) dataset = dataset.repeat(n_epochs) if shuffle: dataset = dataset.shuffle(len(feature_data)) dataset = dataset.batch(1) features = dataset return features