Pytorch TypeError: eq() 接收到无效的参数组合

num_samples = 10def predict(x):    sampled_models = [guide(None, None) for _ in range(num_samples)]    yhats = [model(x).data for model in sampled_models]    mean = torch.mean(torch.stack(yhats), 0)    return np.argmax(mean.numpy(), axis=1)print('当网络被强制预测时的预测结果')correct = 0total = 0for j, data in enumerate(test_loader):    images, labels = data    predicted = predict(images.view(-1,28*28))    total += labels.size(0)    correct += (predicted == labels).sum().item()print("准确率: %d %%" % (100 * correct / total))

错误:

correct += (predicted == labels).sum().item() TypeError: eq() 接收到无效的参数组合 - 接收到 (numpy.ndarray),但期望的是以下之一:  * (Tensor other)  不匹配,因为某些参数类型无效:(!numpy.ndarray!)* (Number other)  不匹配,因为某些参数类型无效:(!numpy.ndarray!)

*


回答:

您试图比较 predictedlabels。然而,您的 predicted 是一个 np.array,而 labels 是一个 torch.tensor,因此 eq()(即 == 操作符)无法在它们之间进行比较。
np.argmax 替换为 torch.argmax

 return torch.argmax(mean, dim=1)

这样应该就可以了。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注