在GPU上预训练的Keras模型能否在仅有CPU的电脑上用于预测?

我目前正在学习Keras。我的问题是,如果我在一台高端GPU(如RTX Titan)上训练了一个模型,是否可以导出该模型,加载到另一台低端机器上的新程序中,然后仅使用CPU进行预测。这是可能的吗?

从理论上讲,我认为这就是机器学习的工作方式。模型在高端GPU上训练,一旦导出,无论目标机器是否有GPU,都可以加载并用于进行预测。

如果不是这样,那么机器学习模型是否部署在配备多GPU的高端服务器上?


回答:

在Keras中,这是可以无缝工作的。Keras使用TensorFlow后端会检查GPU是否可用,如果可用,模型将在GPU上训练。

同样,在进行推理时加载模型,如果没有GPU可用,它将使用CPU。

使用Google Colab进行实验

让我们启动一个使用“GPU”运行时的Google Colab

import numpy as npfrom keras.models import Sequentialfrom keras.layers import Denseimport tensorflow as tf tf.compat.v1.debugging.set_log_device_placement(True)print(tf.config.list_physical_devices('GPU'))model = Sequential()model.add(Dense(1024, input_dim=8, activation='relu'))model.add(Dense(1, activation='sigmoid'))model.compile(loss='mse', optimizer='adam', metrics=['accuracy'])X = np.random.randn(10,8) y = np.random.randn(10) model.fit(X, y, epochs=2)model.save("model.h5")

输出

[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]Epoch 1/21/1 [==============================] - 0s 1ms/step - loss: 0.6570 - accuracy: 0.0000e+00Epoch 2/21/1 [==============================] - 0s 983us/step - loss: 0.6242 - accuracy: 0.0000e+00<tensorflow.python.keras.callbacks.History at 0x7fcad09366a0>

所以在这种情况下,模型是在可用的GPU上训练的。你可以使用命令!nvidia-smi看到它占用了GPU。我们已经将模型保存为model.h5。让我们下载它并制作一个本地副本

现在让我们将Colab的运行时更改为“CPU”。让我们将我们的model.h5上传到Colab并进行预测。

import numpy as npfrom keras.models import Sequentialfrom keras.layers import Denseimport tensorflow as tf from keras.models import load_modeltf.compat.v1.debugging.set_log_device_placement(True)print(tf.config.list_physical_devices('GPU'))model = load_model('model.h5')model.predict(X)

输出:

[]array([[0.4464949 ],       [0.43229908],       [0.49823508],       [0.4367126 ],       [0.47648385],       [0.48096564],       [0.47863394],       [0.5031184 ],       [0.45698297],       [0.45885688]], dtype=float32)

正如你所见,预期没有GPU可用,模型已加载并在CPU上运行了预测。

Keras的这种方式是无缝的。但在PyTorch中不同,我们必须手动将模型从GPU移动到CPU。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注