无效的参数组合 – eq()

我使用这里共享的代码来测试一个CNN图像分类器。当我调用测试函数时,在第155行遇到了这个错误:

test_acc += torch.sum(prediction == labels.data)TypeError: eq() received an invalid combination of arguments - got (numpy.ndarray), but expected one of: * (Tensor other)      didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m) * (Number other)      didn't match because some of the arguments have invalid types: ([31;1mnumpy.ndarray[0m)

test函数的片段:

def test():    model.eval()    test_acc = 0.0    for i, (images, labels) in enumerate(test_loader):        if cuda_avail:                images = Variable(images.cuda())                labels = Variable(labels.cuda())        #使用测试集中的图像预测类别        outputs = model(images)        _,prediction = torch.max(outputs.data, 1)        prediction = prediction.cpu().numpy()        test_acc += torch.sum(prediction == labels.data) #line 155    #计算所有10000个测试图像的平均准确率和损失    test_acc = test_acc / 10000return test_acc

经过快速搜索,我发现这个错误可能与predictionlabels之间的比较有关,就像这个Stack Overflow问题中提到的那样。

有什么解决方法吗?


回答:

为什么在这里prediction = prediction.cpu().numpy()使用.numpy()?这样做会将PyTorch张量转换为NumPy数组,使其无法与labels.data进行比较,因为类型不兼容。

去掉.numpy()部分应该可以解决这个问题。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注