如何重写TensorFlow图以便所有操作都在CPU上运行

我已经在一个多GPU和CPU设置上训练了一个网络,并将结果模型保存为TensorFlow的SavedModel。然后我有另一个脚本,可以加载结果模型并运行所需的操作来进行预测,即在模型上进行推理。这在与模型训练相同的设置上可以正常工作。

然而,我需要将模型部署到只有1个CPU且没有GPU的设备上。当我尝试运行相同的脚本时,我得到了以下错误:

InvalidArgumentError(请参见上面的跟踪信息):从检查点恢复失败。这很可能是由于当前图和检查点图之间的不匹配引起的。请确保您没有更改基于检查点的预期图。原始错误:

无法为操作default_policy_1/tower_1/Variable分配设备:节点default_policy_1/tower_1/Variable(在restore.py:56定义)明确分配给了/device:GPU:0,但可用的设备是[/job:localhost/replica:0/task:0/device:CPU:0, /job:localhost/replica:0/task:0/device:XLA_CPU:0]。请确保设备规格指的是有效的设备。请求的设备似乎是GPU,但CUDA未启用。[[节点default_policy_1/tower_1/Variable(在restore.py:56定义)]]

这个看起来有希望,但代码并没有改变我的图,一点节点都没有被移除 – Remove operation graph tensorflow to run on CPU

总的来说,似乎不应该简单地移除所有不能在CPU上运行的操作

我尝试过将所有内容包装在with tf.device('CPU:0')块中,以及使用config = tf.ConfigProto(device_count={'GPU': 0}),但两者都没有改变错误。

相关代码:

from tensorflow.python.saved_model import loaderinput_tensor_key_feed_dict = {'observations': np.array([[23]]), 'prev_action': np.array([0]),                              'prev_reward': np.array([0]), 'is_training': False}config = tf.ConfigProto(device_count={'GPU': 0})with tf.device('CPU:0'):    with session.Session(None, graph=ops_lib.Graph(), config=config) as sess:                loader.load(sess, tag_set.split(','), saved_model_dir) #错误在这里发生                outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)        for i, output in enumerate(outputs):            output_tensor_key = output_tensor_keys_sorted[i]            print('Result for output key %s:\t%s' % (output_tensor_key, output))

回答:

我会初始化一个没有设备规格的新模型,然后只加载模型变量,就像它是一个标准的训练检查点一样,使用tf.Saver()。在这一点上,您应该能够保存一个版本的SavedModel,TensorFlow可以决定在哪里放置操作。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注