为什么我的简单PyTorch网络在GPU设备上无法工作?

我按照教程构建了一个简单的网络,但遇到了以下错误:

RuntimeError: 期望类型为torch.cuda.FloatTensor的对象,但为参数#4 ‘mat1’找到了类型torch.FloatTensor

有什么帮助吗?谢谢!


回答:

TL;DR
这是修复方法

inputs = inputs.to(device)  

为什么?!
torch.nn.Module.to()torch.Tensor.to() 之间存在细微的差异:虽然 Module.to() 是一个原地操作符,但 Tensor.to() 不是。因此

net.to(device)

会改变 net 本身并将其移动到 device。另一方面

inputs.to(device)

不会改变 inputs,而是返回一个位于 device 上的 inputs副本。为了使用这个“在设备上”的副本,你需要将其赋值给一个变量,因此

inputs = inputs.to(device)

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注