RuntimeError: 输入类型(torch.FloatTensor)和权重类型(torch.cuda.FloatTensor)应相同

如下代码会导致错误:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)for data in dataloader:    inputs, labels = data    outputs = model(inputs)

引发的错误是:

RuntimeError: 输入类型(torch.FloatTensor)和权重类型(torch.cuda.FloatTensor)应相同


回答:

您会遇到这个错误,因为您的模型在GPU上,而数据在CPU上。因此,您需要将输入张量发送到GPU上。

inputs, labels = data                         # 这是您原有的代码inputs, labels = inputs.cuda(), labels.cuda() # 添加这一行

或者,为了与您的其他代码保持一致,可以这样做:

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")inputs, labels = inputs.to(device), labels.to(device)

如果您的输入张量在GPU上而模型权重不在GPU上,也会引发相同的错误。在这种情况下,您需要将模型权重发送到GPU上。

model = MyModel()if torch.cuda.is_available():    model.cuda()

请参阅文档了解cuda(),以及它的反操作cpu()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注