如何在PyTorch多类问题中设置交叉熵损失的目标

问题陈述:我有一张图片,图片中的一个像素只能属于Band5','Band6', 'Band7'中的一个(详情见下文)。因此,我有一个PyTorch多类问题,但我无法理解如何设置目标,这些目标需要是[batch, w, h]的形式。

我的数据加载器返回两个值:

x = chips.loc[:, :, :, self.input_bands]     y = chips.loc[:, :, :, self.output_bands]        x = x.transpose('chip','channel','x','y')y_ohe = y.transpose('chip','channel','x','y')

此外,我定义了:

input_bands = ['Band1','Band2', 'Band3', 'Band3', 'Band4']  # 输入类output_bands = ['Band5','Band6', 'Band7'] #目标类model = ModelName(num_classes = 3, depth=default_depth, in_channels=5, merge_mode='concat').to(device)loss_new = nn.CrossEntropyLoss()

在我的训练函数中:

        #从数据加载器获取值        X = normalize_zero_to_one(X) #输入        y = normalize_zero_to_one(y) #目标        images = Variable(torch.from_numpy(X)).to(device) # [batch, channel, H, W]        masks = Variable(torch.from_numpy(y)).to(device)         optim.zero_grad()                outputs = model(images)         loss = loss_new(outputs, masks) # (preds, target)        loss.backward()                 optim.step() # 更新权重  

我知道目标(这里是masks)应该是[batch_size, w, h]。然而,目前它是[batch_size, channels, w, h]

我阅读了很多帖子,包括12,它们都说目标应该只包含目标类索引。我不明白如何将三个类的索引连接起来,同时仍然将目标设置为[batch_size, w, h]

目前,我收到的错误是:

RuntimeError: only batches of spatial targets supported (3D tensors) but got targets of dimension: 4

据我所知,我不需要进行任何独热编码。在互联网上找到的类似错误和解释如下:

任何帮助都将不胜感激!谢谢你。


回答:

如果我理解正确的话,你当前的“目标”是[batch_size, channels, w, h],其中channels==3因为你有三个可能的目标。
你的目标中的代表什么?你基本上是为每个像素提供了一个3向量目标 – 这些是预期的类概率吗?它们是指示正确“波段”的“独热向量”吗?如果是的话,你可以通过在目标通道维度上简单地取argmax来获取目标索引:

proper_target = torch.argmax(masks, dim=1)  # 确保keepdim=False loss = loss_new(outputs, proper_target)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注