Tensorflow – 停止恢复网络参数

我在尝试使用Tensorflow网络进行多次连续预测,但即使在CPU上,性能似乎也非常差(对于一个2层8×8的卷积网络,每次预测约需500毫秒)。我怀疑问题的一部分是每次似乎都在重新加载网络参数。下面的代码中每次调用classifier.predict都会产生以下输出行,因此我会看到数百次这样的输出。

INFO:tensorflow:Restoring parameters from /tmp/model_data/model.ckpt-102001

如何重用已经加载的检查点?

(我在这里不能进行批量预测,因为网络的输出是一个在游戏中需要应用到当前状态的动作,然后再将新的游戏状态输入网络。)

以下是进行预测的循环代码。

def rollout(classifier, state):  while not state.terminated:    predict_input_fn = tf.estimator.inputs.numpy_input_fn(x={"x": state.as_nn_input()}, shuffle=False)    prediction = next(classifier.predict(input_fn=predict_input_fn))    index = np.random.choice(NUM_ACTIONS, p=prediction["probabilities"]) # 根据网络输出的概率选择一个动作    state.apply_move(index)

classifier 是一个用以下代码创建的 tf.estimator.Estimator

classifier = tf.estimator.Estimator(      model_fn=cnn_model_fn, model_dir=os.path.join(tempfile.gettempdir(), 'model_data'))

回答:

Estimator API 是一个高级API。

tf.estimator框架通过其高级Estimator API使得构建和训练机器学习模型变得简单。Estimator提供了一些类,您可以实例化这些类来快速配置常见的模型类型,如回归器和分类器。

Estimator API抽象了TensorFlow的大量复杂性,但在此过程中也失去了一些通用性。阅读了代码后,很明显没有办法在不每次重新加载模型的情况下运行多个连续预测。低级TensorFlow API允许这种行为。但是…

Keras 是一个支持这种用例的高级框架。只需定义模型,然后重复调用predict即可。

def rollout(model, state):  while not state.terminated:    predictions = model.predict(state.as_nn_input())    for _, prediction in enumerate(predictions):      index = np.random.choice(bt.ACTIONS, p=prediction)      state.apply_mode(index)

非科学的基准测试显示,这大约快了100倍。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注