PyTorch模型验证:张量a的大小(32)必须与张量b的大小(13)匹配

我对机器学习来说还是一个初学者。为了学习,我正在尝试开发一个简单的CNN来分类棋子。这个网络已经可以运行,我也可以训练它,但我在验证函数上遇到了问题。

我无法将我的预测与target_data进行比较,因为我的预测只是一个大小为13的张量,而target.data[batch_size]x13。我找不出我的错误在哪里。PyTorch的示例几乎都使用这个函数来比较预测和目标数据。

如果有人能帮我解决这个问题,将非常感激。

您可以在这里查看剩余的代码:https://github.com/michaelwolz/ChessML/blob/master/train.ipynb

def validate(model, validation_data, criterion):    model.eval()    loss = 0    correct = 0    for i in range(len(validation_data)):        data, target = validation_data[i][0], validation_data[i][1]        target = torch.Tensor(target)        if torch.cuda.is_available():            data = data.cuda()            target = target.cuda()        out = model(data)        loss += criterion(out, target).item()        _, prediction = torch.max(out.data, 1)        correct += (prediction == target.data).sum().item()    loss = loss / len(validation_data)    print("###################################")    print("平均损失:", loss)    print("准确率:", 100. * correct / len(validation_data))    print("###################################")

错误:

<ipython-input-6-6b21e2bfb8a6> in validate(model, validation_data, criterion)     17      18         _, prediction = torch.max(out.data, 1)---> 19         correct += (prediction == target.data).sum().item()     20      21     loss = loss / len(validation_data)RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 1

编辑:我的标签看起来像这样:

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

每个索引代表一个类别。
torch.max()函数的输出似乎是类的索引。我不明白如何将索引与target_label进行比较。我的意思是,我可以编写一个函数来检查预测的索引处是否有一个1,但我认为我的错误在其他地方。


回答:

简单地对目标也运行”argmax”:

_, target = torch.max(target.data, 1)

或者更好的是,将目标保持为[example_1_class, example_2_class, ...],而不是使用1-hot编码。

Related Posts

在使用k近邻算法时,有没有办法获取被使用的“邻居”?

我想找到一种方法来确定在我的knn算法中实际使用了哪些…

Theano在Google Colab上无法启用GPU支持

我在尝试使用Theano库训练一个模型。由于我的电脑内…

准确性评分似乎有误

这里是代码: from sklearn.metrics…

Keras Functional API: “错误检查输入时:期望input_1具有4个维度,但得到形状为(X, Y)的数组”

我在尝试使用Keras的fit_generator来训…

如何使用sklearn.datasets.make_classification在指定范围内生成合成数据?

我想为分类问题创建合成数据。我使用了sklearn.d…

如何处理预测时不在训练集中的标签

已关闭。 此问题与编程或软件开发无关。目前不接受回答。…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注