我对机器学习来说还是一个初学者。为了学习,我正在尝试开发一个简单的CNN来分类棋子。这个网络已经可以运行,我也可以训练它,但我在验证函数上遇到了问题。
我无法将我的预测与target_data
进行比较,因为我的预测只是一个大小为13的张量,而target.data
是[batch_size]x13
。我找不出我的错误在哪里。PyTorch的示例几乎都使用这个函数来比较预测和目标数据。
如果有人能帮我解决这个问题,将非常感激。
您可以在这里查看剩余的代码:https://github.com/michaelwolz/ChessML/blob/master/train.ipynb
def validate(model, validation_data, criterion): model.eval() loss = 0 correct = 0 for i in range(len(validation_data)): data, target = validation_data[i][0], validation_data[i][1] target = torch.Tensor(target) if torch.cuda.is_available(): data = data.cuda() target = target.cuda() out = model(data) loss += criterion(out, target).item() _, prediction = torch.max(out.data, 1) correct += (prediction == target.data).sum().item() loss = loss / len(validation_data) print("###################################") print("平均损失:", loss) print("准确率:", 100. * correct / len(validation_data)) print("###################################")
错误:
<ipython-input-6-6b21e2bfb8a6> in validate(model, validation_data, criterion) 17 18 _, prediction = torch.max(out.data, 1)---> 19 correct += (prediction == target.data).sum().item() 20 21 loss = loss / len(validation_data)RuntimeError: The size of tensor a (32) must match the size of tensor b (13) at non-singleton dimension 1
编辑:我的标签看起来像这样:
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
每个索引代表一个类别。
torch.max()
函数的输出似乎是类的索引。我不明白如何将索引与target_label
进行比较。我的意思是,我可以编写一个函数来检查预测的索引处是否有一个1,但我认为我的错误在其他地方。
回答:
简单地对目标也运行”argmax”:
_, target = torch.max(target.data, 1)
或者更好的是,将目标保持为[example_1_class, example_2_class, ...]
,而不是使用1-hot编码。