PyTorch模型验证:张量a的大小(32)必须与张量b的大小(13)匹配

我对机器学习来说还是一个初学者。为了学习,我正在尝试开发一个简单的CNN来分类棋子。这个网络已经可以运行,我也可以训练它,但我在验证函数上遇到了问题。

我无法将我的预测与target_data进行比较,因为我的预测只是一个大小为13的张量,而target.data[batch_size]x13。我找不出我的错误在哪里。PyTorch的示例几乎都使用这个函数来比较预测和目标数据。

如果有人能帮我解决这个问题,将非常感激。

您可以在这里查看剩余的代码:https://github.com/michaelwolz/ChessML/blob/master/train.ipynb

def validate(model, validation_data, criterion):    model.eval()    loss = 0    correct = 0    for i in range(len(validation_data)):        data, target = validation_data[i][0], validation_data[i][1]        target = torch.Tensor(target)        if torch.cuda.is_available():            data = data.cuda()            target = target.cuda()        out = model(data)        loss += criterion(out, target).item()        _, prediction = torch.max(out.data, 1)        correct += (prediction == target.data).sum().item()    loss = loss / len(validation_data)    print("###################################")    print("平均损失:", loss)    print("准确率:", 100. * correct / len(validation_data))    print("###################################")

错误:

<ipython-input-6-6b21e2bfb8a6> in validate(model, validation_data, criterion)     17      18         _, prediction = torch.max(out.data, 1)---> 19         correct += (prediction == target.data).sum().item()     20      21     loss = loss / len(validation_data)RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 1

编辑:我的标签看起来像这样:

[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]

每个索引代表一个类别。
torch.max()函数的输出似乎是类的索引。我不明白如何将索引与target_label进行比较。我的意思是,我可以编写一个函数来检查预测的索引处是否有一个1,但我认为我的错误在其他地方。


回答:

简单地对目标也运行”argmax”:

_, target = torch.max(target.data, 1)

或者更好的是,将目标保持为[example_1_class, example_2_class, ...],而不是使用1-hot编码。

Related Posts

关于k折交叉验证的直观问题

我在使用交叉验证检查预测能力时遇到了一些直观问题,我认…

调整numpy数组大小以使用sklearn的train_test_split函数?

我正在尝试使用sklearn中的test_train_…

如何转换二维张量和索引张量以便用于torch.nn.utils.rnn.pack_sequence

我有一组序列,格式如下: sequences = to…

模型预测值的含义是什么?

我在网上找到一个数字识别器的CNN模型并进行了训练,当…

锯齿张量作为LSTM的输入

了解锯齿张量以及如何在TensorFlow中使用它们。…

如何告诉SciKit的LinearRegression模型预测值不能小于零?

我有以下代码,尝试根据非价格基础特征来估值股票。 pr…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注