在神经网络训练中,如何正确地最大化一个损失函数并最小化另一个损失函数?

我有一个简单的NN:

import torchimport torch.nn as nnimport torch.optim as optimclass Model(nn.Module):    def __init__(self):        super(Model, self).__init__()        self.fc1 = nn.Linear(1, 5)        self.fc2 = nn.Linear(5, 10)        self.fc3 = nn.Linear(10, 1)    def forward(self, x):        x = self.fc1(x)        x = torch.relu(x)                x = torch.relu(self.fc2(x))        x = self.fc3(x)        return xnet = Model()opt = optim.Adam(net.parameters())

我还有一些输入特征:

features = torch.rand((3,1)) 

我可以用一个简单的损失函数进行正常训练,该函数将被最小化:

for i in range(10):    opt.zero_grad()    out = net(features)    loss = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))    print('loss:', loss)    loss.backward()    opt.step()

然而,如果我添加另一个想要最大化的损失组件loss2:

loss2s = []for i in range(10000):    opt.zero_grad()    out = net(features)    loss1 = torch.mean(torch.square(torch.tensor(5) - torch.sum(out)))    loss2 = torch.sum(torch.tensor([torch.sum(w_arr) for w_arr in net.parameters()]))    loss2s.append(loss2)    loss = loss1 + loss2    loss.backward()    opt.step()

由于两个损失的尺度不同,似乎变得不稳定。此外,我不确定这是正确的方法,因为损失函数如何知道要最大化一部分并最小化另一部分。请注意,这只是一个例子,显然增加权重没有意义。

import matplotlib.pyplot as pltplt.plot(loss2s, c='r')plt.plot(loss1s, c='b')

enter image description here

而且我相信在机器学习中,最小化函数是常见的训练方式,所以我不确定将最大化问题转换成某种形式的最小化问题是否会更好。


回答:

表示“最小化”和“最大化”的标准方法是改变符号。PyTorch总是最小化一个loss,如果执行以下操作:

loss.backward()

因此,如果需要最大化另一个loss2,我们添加其负值:

overall_loss = loss + (- loss2)overall_loss.backward()

因为最小化一个负值等同于最大化原始的正值。

关于“尺度”,是的,尺度确实很重要。通常为了匹配尺度,会执行以下操作:

overall_loss = loss + alpha * (- loss2)

其中alpha是一个分数,表示一个损失相对于另一个损失的相对重要性。这是一个超参数,需要进行实验。


撇开技术细节不谈,结果损失是否稳定在很大程度上取决于具体问题和所涉及的损失函数。如果损失是矛盾的,你可能会经历不稳定性。处理这些问题的方法本身就是一个研究问题,远远超出了这个问题的范围。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注