PyTorch模型验证:张量a的大小(32)必须与张量b的大小(13)匹配

我对机器学习来说还是一个初学者。为了学习,我正在尝试开发一个简单的CNN来分类棋子。这个网络已经可以运行,我也可以训练它,但我在验证函数上遇到了问题。

我无法将我的预测与target_data进行比较,因为我的预测只是一个大小为13的张量,而target.data[batch_size]x13。我找不出我的错误在哪里。PyTorch的示例几乎都使用这个函数来比较预测和目标数据。

如果有人能帮我解决这个问题,将非常感激。

您可以在这里查看剩余的代码:https://github.com/michaelwolz/ChessML/blob/master/train.ipynb

def validate(model, validation_data, criterion):    model.eval()    loss = 0    correct = 0    for i in range(len(validation_data)):        data, target = validation_data[i][0], validation_data[i][1]        target = torch.Tensor(target)        if torch.cuda.is_available():            data = data.cuda()            target = target.cuda()        out = model(data)        loss += criterion(out, target).item()        _, prediction = torch.max(out.data, 1)        correct += (prediction == target.data).sum().item()    loss = loss / len(validation_data)    print("###################################")    print("平均损失:", loss)    print("准确率:", 100. * correct / len(validation_data))    print("###################################")

错误:

<ipython-input-6-6b21e2bfb8a6> in validate(model, validation_data, criterion)     17      18         _, prediction = torch.max(out.data, 1)---> 19         correct += (prediction == target.data).sum().item()     20      21     loss = loss / len(validation_data)RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 1

编辑:我的标签看起来像这样:

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

每个索引代表一个类别。
torch.max()函数的输出似乎是类的索引。我不明白如何将索引与target_label进行比较。我的意思是,我可以编写一个函数来检查预测的索引处是否有一个1,但我认为我的错误在其他地方。


回答:

简单地对目标也运行”argmax”:

_, target = torch.max(target.data, 1)

或者更好的是,将目标保持为[example_1_class, example_2_class, ...],而不是使用1-hot编码。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注