为什么在线性层后使用ReLU激活函数会导致准确率下降

我开始使用Pytorch,并在FashionMNIST数据集上构建了一个非常基础的CNN。我在使用神经网络时注意到了一些奇怪的行为,我不知道为什么会发生这种情况,在前向函数中,当我在每个线性层之后使用ReLU函数时,神经网络的准确率会下降。

这是我自定义神经网络的代码:

# custom class neural network class FashionMnistClassifier(nn.Module):  def __init__(self, n_inputs, n_out):    super().__init__()    self.cnn1 = nn.Conv2d(n_inputs, out_channels=32, kernel_size=5).cuda(device)    self.cnn2 = nn.Conv2d(32, out_channels=64, kernel_size=5).cuda(device)    #self.cnn3 = nn.Conv2d(n_inputs, out_channels=32, kernel_size=5)    self.fc1 = nn.Linear(64*4*4, out_features=100).cuda(device)    self.fc2 = nn.Linear(100, out_features=n_out).cuda(device)    self.relu = nn.ReLU().cuda(device)    self.pool = nn.MaxPool2d(kernel_size=2).cuda(device)    self.soft_max = nn.Softmax().cuda(device)  def forward(self, x):    x.cuda(device)    out = self.relu(self.cnn1(x))    out = self.pool(out)    out = self.relu(self.cnn2(out))    out = self.pool(out)    #print("out shape in classifier forward func: ", out.shape)    out = self.fc1(out.view(out.size(0), -1))    #out = self.relu(out) # 如果我取消注释这些行,准确率会从90%下降到50%!!!    out = self.fc2(out)    #out = self.relu(out) # 这行也是    return outn_batch = 100n_outputs = 10LR = 0.001model = FashionMnistClassifier(1, 10).cuda(device)optimizer = optim.Adam(model.parameters(), lr=LR)criterion = nn.CrossEntropyLoss()

所以,如果我只在CNN层之后使用ReLU,我可以得到90%的准确率,但当我取消注释那部分代码并在线性层之后使用ReLU激活函数时,准确率会下降到50%。我不知道为什么会发生这种情况,因为我认为在每个线性层之后使用激活函数总是能得到更好的分类准确率。我一直认为在分类问题中我们应该始终使用激活函数,而在线性回归中则不需要这样做,但在我的情况下,尽管这是一个分类问题,如果我在线性层之后不使用激活函数,我会得到更好的性能。能有人帮我解释一下吗?


回答:

CrossEntropyLoss 要求你传递未归一化的logits(最后一个Linear层的输出)。

如果你在最后一层使用ReLU作为输出,你只会输出在[0, inf)范围内的值,而神经网络倾向于对错误标签使用较小的值,对正确标签使用较高的值(我们可以说它对预测过于自信)。哦,还有,最高的logit值会被argmax选为正确标签。

所以这行代码绝对不会工作:

# out = self.relu(out) # 这行也是

尽管在它之前使用ReLU应该是可以的。请记住,更多的非线性并不总是对网络有好处。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注