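I'm trying to adapt the feedforward network from https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/01-basics/feedforward_neural_network/main.py to use my own dataset.
I defined a custom dataset with three 1-D arrays (four values each) as inputs and three scalars as the corresponding outputs: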
import torch
import torch.utils.data as data_utils

x = torch.tensor([[5.5, 3, 3, 4], [1, 2, 3, 4], [9, 2, 3, 4]])
print(x)
y = torch.tensor([1, 2, 3])
print(y)

my_train = data_utils.TensorDataset(x, y)
my_train_loader = data_utils.DataLoader(my_train, batch_size=50, shuffle=True)
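A quick shape check of this dataset helps when sizing the network. The sketch below only reuses the tensors defined above; it shows that each sample carries 4 features (not 2), and that PyTorch infers `float32` for `x` and `int64` for `y`.

```python
import torch

# The custom dataset from the question: 3 samples, each a 1-D array of 4 values
x = torch.tensor([[5.5, 3, 3, 4], [1, 2, 3, 4], [9, 2, 3, 4]])
y = torch.tensor([1, 2, 3])

num_samples, num_features = x.shape
print(x.shape, x.dtype)  # torch.Size([3, 4]) torch.float32
print(y.shape, y.dtype)  # torch.Size([3]) torch.int64
print("features per sample:", num_features)  # 4
```

The second dimension of `x` is the number that the first `nn.Linear` layer's `in_features` has to match.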
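I updated the hyperparameters to match the new input size (2) and number of classes (3).
I also changed `images = images.reshape(-1, 28*28).to(device)`
to `images = images.reshape(-1, 4).to(device)`.
Because the training set is so small, I changed the batch size to 1.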
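Note that the loader above passes `batch_size=50` directly, so changing a separate batch-size hyperparameter has no effect on it unless that variable is passed to `DataLoader`. The sketch below (reusing the question's data) shows the batch shapes each setting produces; a single batch of shape `(3, 4)` would be consistent with the `m1: [3 x 4]` in the error below.

```python
import torch
import torch.utils.data as data_utils

x = torch.tensor([[5.5, 3, 3, 4], [1, 2, 3, 4], [9, 2, 3, 4]])
y = torch.tensor([1, 2, 3])
my_train = data_utils.TensorDataset(x, y)

# batch_size=50 with only 3 samples: one (smaller) batch containing all 3 samples
loader_50 = data_utils.DataLoader(my_train, batch_size=50, shuffle=True)
batch_shapes_50 = [tuple(images.shape) for images, labels in loader_50]

# batch_size=1: three batches of one sample each; reshape(-1, 4) keeps them as (1, 4)
loader_1 = data_utils.DataLoader(my_train, batch_size=1, shuffle=True)
batch_shapes_1 = [tuple(images.reshape(-1, 4).shape) for images, labels in loader_1]

print(batch_shapes_50)  # [(3, 4)]
print(batch_shapes_1)   # [(1, 4), (1, 4), (1, 4)]
```

Either way the feature dimension stays 4, so the batch size alone cannot fix a layer that expects 2 input features.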
After making these changes, I get the following error when I try to train:
RuntimeError                              Traceback (most recent call last)
<ipython-input-52-9cdca58f3ef6> in <module>()
     51 
     52     # Forward pass
---> 53     outputs = model(images)
     54     loss = criterion(outputs, labels)
     55 

/home/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    489             result = self._slow_forward(*input, **kwargs)
    490         else:
--> 491             result = self.forward(*input, **kwargs)
    492         for hook in self._forward_hooks.values():
    493             hook_result = hook(self, input, result)

<ipython-input-52-9cdca58f3ef6> in forward(self, x)
     31 
     32     def forward(self, x):
---> 33         out = self.fc1(x)
     34         out = self.relu(out)
     35         out = self.fc2(out)

/home/.local/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    489             result = self._slow_forward(*input, **kwargs)
    490         else:
--> 491             result = self.forward(*input, **kwargs)
    492         for hook in self._forward_hooks.values():
    493             hook_result = hook(self, input, result)

/home/.local/lib/python3.6/site-packages/torch/nn/modules/linear.py in forward(self, input)
     53 
     54     def forward(self, input):
---> 55         return F.linear(input, self.weight, self.bias)
     56 
     57     def extra_repr(self):

/home/.local/lib/python3.6/site-packages/torch/nn/functional.py in linear(input, weight, bias)
    990     if input.dim() == 2 and bias is not None:
    991         # fused op is marginally faster
--> 992         return torch.addmm(bias, input, weight.t())
    993 
    994     output = input.matmul(weight.t())

RuntimeError: size mismatch, m1: [3 x 4], m2: [2 x 3] at /pytorch/aten/src/THC/generic/THCTensorMathBlas.cu:249
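Reading the last frame: `torch.addmm(bias, input, weight.t())` multiplies `m1` (the input batch, 3 samples by 4 features) with `m2` (the transposed `fc1` weight, `in_features` by `out_features`). `m2: [2 x 3]` therefore means `fc1` was built with `in_features=2`, while each sample has 4 features. The sketch below reproduces this on CPU, assuming `fc1` was created as `nn.Linear(2, 3)`; the exact error text differs between PyTorch versions, so it is printed rather than asserted.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
images = torch.tensor([[5.5, 3, 3, 4], [1, 2, 3, 4], [9, 2, 3, 4]])  # shape (3, 4)

# Assumption: fc1 built with in_features=2, out_features=3, so weight.t() is (2, 3)
bad_fc1 = nn.Linear(2, 3)
print(tuple(bad_fc1.weight.t().shape))  # (2, 3)
try:
    bad_fc1(images)
    mismatch_raised = False
except RuntimeError as err:
    mismatch_raised = True
    print("RuntimeError:", err)

# With in_features equal to the number of features per sample, the product is defined
good_fc1 = nn.Linear(4, 3)
out = good_fc1(images)
print(tuple(out.shape))  # (3, 3)
```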
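How should I change the code so that the dimensions match? I'm not sure which part to change, since I've already updated every parameter I thought needed updating.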
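One way to make the shapes line up is sketched below. It is not the tutorial's code verbatim: the `NeuralNet` class mirrors the tutorial's two-layer structure (`fc1`, ReLU, `fc2`), `hidden_size=8`, the epoch count, and the learning rate are arbitrary assumptions, and it runs on CPU. It sets `input_size = 4` to match the 4 features per sample. It also shifts the labels from 1..3 to 0..2, assuming they denote three classes, because `nn.CrossEntropyLoss` (which the tutorial uses) expects class indices in `[0, num_classes - 1]`; a target of 3 with 3 classes would raise a separate error once the size mismatch is gone.

```python
import torch
import torch.nn as nn
import torch.utils.data as data_utils

torch.manual_seed(0)

input_size = 4      # each sample has 4 features, so fc1 must accept 4 inputs
hidden_size = 8     # arbitrary choice for this sketch
num_classes = 3
num_epochs = 5
learning_rate = 0.01

x = torch.tensor([[5.5, 3, 3, 4], [1, 2, 3, 4], [9, 2, 3, 4]])
# Assumption: labels 1..3 mean classes 0..2; CrossEntropyLoss needs indices in [0, num_classes - 1]
y = torch.tensor([1, 2, 3]) - 1

my_train = data_utils.TensorDataset(x, y)
my_train_loader = data_utils.DataLoader(my_train, batch_size=1, shuffle=True)

class NeuralNet(nn.Module):
    """Two-layer feed-forward net in the style of the tutorial's model."""

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        return self.fc2(out)

device = torch.device("cpu")  # CPU for the sketch; the tutorial selects CUDA when available
model = NeuralNet(input_size, hidden_size, num_classes).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    for images, labels in my_train_loader:
        images = images.reshape(-1, input_size).to(device)  # (batch, 4)
        labels = labels.to(device)
        outputs = model(images)                # (batch, num_classes)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f"epoch {epoch + 1}, loss {loss.item():.4f}")

with torch.no_grad():
    logits = model(x.to(device))
print(tuple(logits.shape))  # (3, 3)
```

Using `input_size` in the `reshape` call, rather than a literal, keeps the reshape and `fc1` in agreement if the feature count changes again.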
Source code before the changes:
...
Answer: