加载和使用已保存的BoostedTreesClassifier模型

我正在尝试在TensorFlow中使用一个已保存的BoostedTreesClassifier模型,但无法弄清楚如何使用加载的模型进行预测。我使用的是教程中的示例代码。这是我在使用的简化代码:

import pandas as pdfrom matplotlib import pyplot as pltimport tensorflow as tftf.random.set_seed(123)dftrain = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/train.csv')dfeval = pd.read_csv('https://storage.googleapis.com/tf-datasets/titanic/eval.csv')y_train = dftrain.pop('survived')y_eval = dfeval.pop('survived')CATEGORICAL_COLUMNS = ['sex', 'n_siblings_spouses', 'parch', 'class', 'deck',                       'embark_town', 'alone']NUMERIC_COLUMNS = ['age', 'fare']def one_hot_cat_column(feature_name, vocab):  return tf.feature_column.indicator_column(      tf.feature_column.categorical_column_with_vocabulary_list(feature_name,                                                 vocab))feature_columns = []for feature_name in CATEGORICAL_COLUMNS:  # 需要对分类特征进行独热编码。  vocabulary = dftrain[feature_name].unique()  feature_columns.append(one_hot_cat_column(feature_name, vocabulary))for feature_name in NUMERIC_COLUMNS:  feature_columns.append(tf.feature_column.numeric_column(feature_name,                                           dtype=tf.float32))NUM_EXAMPLES = len(y_train)def make_input_fn(X, y, n_epochs=None, shuffle=True):  def input_fn():    dataset = tf.data.Dataset.from_tensor_slices((dict(X), y))    if shuffle:      dataset = dataset.shuffle(NUM_EXAMPLES)    # 对于训练,需要根据需要多次循环数据集(n_epochs=None)。    dataset = dataset.repeat(n_epochs)    # 内存训练不使用批处理。    dataset = dataset.batch(NUM_EXAMPLES)    return dataset  return input_fntrain_input_fn = make_input_fn(dftrain, y_train)eval_input_fn = make_input_fn(dfeval, y_eval, shuffle=False, n_epochs=1)n_batches = 1est = tf.estimator.BoostedTreesClassifier(feature_columns,                                          n_batches_per_layer=n_batches)est.train(train_input_fn, max_steps=100)result = est.evaluate(eval_input_fn)# 进行预测pred_dicts = list(est.predict(eval_input_fn))probs = pd.Series([pred['probabilities'][1] for pred in pred_dicts])probs.plot(kind='hist', bins=20, title='predicted probabilities')plt.show()# 保存模型feature_spec = tf.feature_column.make_parse_example_spec(feature_columns)serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(feature_spec)saved_model_path = est.export_saved_model('saved_model', serving_input_receiver_fn)# 加载模型loaded_est = tf.saved_model.load(saved_model_path)# 如何使用loaded_est进行预测?# pred_dicts_using_loaded_model = list(loaded_est.predict(eval_input_fn))# probs_using_loaded_model = pd.Series([pred['probabilities'][1] for pred in pred_dicts_using_loaded_model])

编辑:最后两行注释的代码只是为了展示我最终想要实现的目标。它们无法正确运行,也不是为了运行,因为loaded_est是一种与est不同的对象。我不知道如何像使用est那样使用loaded_est进行预测。我查看了保存和加载模型的文档这里,他们在图像上进行了操作,但我无法将其转换为此数据,其中输入只是一个向量(即dfeval数据框中的一行)。


回答:

我最终在medium.com上找到了由Ajeet Singh贡献的示例代码,通过搜索稍微更通用的内容找到了。为了回答我自己的问题,以下是使用加载的模型进行预测的代码:

def predict(loaded_model, row, columns, dtypes):  example = tf.train.Example()  for i in range(len(columns)):    if dtypes[i] == 'object':      example.features.feature[columns[i]].bytes_list.value.extend([bytes(row[i], 'utf-8')])    elif dtypes[i] == 'float':      example.features.feature[columns[i]].float_list.value.extend([row[i]])    elif dtypes[i] == 'int64':      example.features.feature[columns[i]].int64_list.value.extend([row[i]])  return loaded_model.signatures['predict'](examples=tf.constant([example.SerializeToString()]))pred_dicts_using_loaded_model = [predict(loaded_est, row, dfeval.columns, dfeval.dtypes) for row in dfeval.itertuples(index=False)]probs_using_loaded_model = [pred_dict['probabilities'][0][1].numpy() for pred_dict in pred_dicts_using_loaded_model]for p in zip(probs, probs_using_loaded_model):  print(p)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注