我使用这里共享的代码来测试一个CNN图像分类器。当我调用测试函数时,在第155行遇到了这个错误:
test_acc += torch.sum(prediction == labels.data)TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of: * (Tensor other) didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m) * (Number other) didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m)
test
函数的片段:
def test(): model.eval() test_acc = 0.0 for i, (images, labels) in enumerate(test_loader): if cuda_avail: images = Variable(images.cuda()) labels = Variable(labels.cuda()) #使用测试集中的图像预测类别 outputs = model(images) _,prediction = torch.max(outputs.data, 1) prediction = prediction.cpu().numpy() test_acc += torch.sum(prediction == labels.data) #line 155 #计算所有10000个测试图像的平均准确率和损失 test_acc = test_acc / 10000return test_acc
经过快速搜索,我发现这个错误可能与prediction
和labels
之间的比较有关,就像这个Stack Overflow问题中提到的那样。
有什么解决方法吗?
回答:
为什么在这里prediction = prediction.cpu().numpy()
使用.numpy()
?这样做会将PyTorch张量转换为NumPy数组,使其无法与labels.data
进行比较,因为类型不兼容。
去掉.numpy()
部分应该可以解决这个问题。