问题陈述:我有一张图片,图片中的一个像素只能属于Band5','Band6', 'Band7'
中的一个(详情见下文)。因此,我有一个PyTorch多类问题,但我无法理解如何设置目标,这些目标需要是[batch, w, h]
的形式。
我的数据加载器返回两个值:
x = chips.loc[:, :, :, self.input_bands] y = chips.loc[:, :, :, self.output_bands] x = x.transpose('chip','channel','x','y')y_ohe = y.transpose('chip','channel','x','y')
此外,我定义了:
input_bands = ['Band1','Band2', 'Band3', 'Band3', 'Band4'] # 输入类output_bands = ['Band5','Band6', 'Band7'] #目标类model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)loss_new = nn.CrossEntropyLoss()
在我的训练函数中:
#从数据加载器获取值 X = normalize_zero_to_one(X) #输入 y = normalize_zero_to_one(y) #目标 images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W] masks = Variable(torch.from_numpy(y)).to(device) optim.zero_grad() outputs = model(images) loss = loss_new(outputs, masks) # (preds, target) loss.backward() optim.step() # 更新权重
我知道目标(这里是masks
)应该是[batch_size, w, h]
。然而,目前它是[batch_size, channels, w, h]
。
我阅读了很多帖子,包括1,2,它们都说目标应该只包含目标类索引
。我不明白如何将三个类的索引连接起来,同时仍然将目标设置为[batch_size, w, h]
。
目前,我收到的错误是:
RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4
据我所知,我不需要进行任何独热编码。在互联网上找到的类似错误和解释如下:
任何帮助都将不胜感激!谢谢你。
回答:
如果我理解正确的话,你当前的“目标”是[batch_size, channels, w, h]
,其中channels==3
因为你有三个可能的目标。
你的目标中的值代表什么?你基本上是为每个像素提供了一个3向量目标 – 这些是预期的类概率吗?它们是指示正确“波段”的“独热向量”吗?如果是的话,你可以通过在目标通道维度上简单地取argmax
来获取目标索引:
proper_target = torch.argmax(masks, dim=1) # 确保keepdim=False loss = loss_new(outputs, proper_target)