我在尝试使用PyTorch创建的模型来查找准确率时遇到了错误。最初我遇到了另一个错误,已经修复了,但现在出现了这个错误。
我使用以下代码获取测试集:
testset = torchvision.datasets.FashionMNIST(MNIST_DIR, train=False, download=True, transform=torchvision.transforms.Compose([ torchvision.transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1), torchvision.transforms.ToTensor(), # 图像转为张量 torchvision.transforms.Normalize((0.1307,), (0.3081,)) # 图像,标签 ]))testloader = torch.utils.data.DataLoader(testset, batch_size=100, shuffle=False)
当我尝试访问创建的测试集时,出于某种原因,它试图重新训练模型,然后继续报错。这是我获取准确率并调用测试集的代码:
correct = 0total = 0with torch.no_grad(): print("entered here") for (x, y_gt) in testloader: x = x.to(device) y_gt = y_gt.to(device) outputs = teacher_model(x) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))
这是我遇到的错误:
Traceback (most recent call last): File "[path]/train_teacher_1.py", line 134, in <module> outputs = teacher_model(x) File "[path]\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "[path]\models.py", line 17, in forward x = F.relu(self.layer1(x)) File "[path]\anaconda3\lib\site-packages\torch\nn\modules\module.py", line 727, in _call_impl result = self.forward(*input, **kwargs) File "[path]\anaconda3\lib\site-packages\torch\nn\modules\linear.py", line 93, in forward return F.linear(input, self.weight, self.bias) File "[path]\anaconda3\lib\site-packages\torch\nn\functional.py", line 1692, in linear output = input.matmul(weight.t())RuntimeError: mat1 dim 1 must match mat2 dim 0
如果您需要查看训练模型的其余代码,请告诉我。我没有包括这些代码,因为帖子已经太长了。
我是PyTorch的新手,任何帮助都将不胜感激。提前感谢。
回答:
我已经解决了这个问题,我需要检查x的大小。
我在for循环中添加了以下代码来修复它:x = torch.flatten(x, start_dim=1, end_dim=-1)