PyTorch NN不学习或学习效果差

我在使用PyTorch教程,稍作修改以使用泰坦尼克号数据集。我使用了一个非常简单的网络,由带ReLU的线性(密集)层组成…例如,我想根据年龄、票价和性别来预测生存状态。

我在一款简单的神经网络上遇到了奇怪的行为(我在Google Colab上进行实验)。有时候当我执行训练时,准确率完全没有变化。这很奇怪,因为我重新创建了模型…

准确率: 59.4%, 平均损失: 0.693147[...50行或更多类似的行...]准确率: 59.4%, 平均损失: 0.693147

有时候准确率会从60%缓慢增加到80%。

另一件事是,尽管我用的是…同一个训练集进行验证,但准确率非常低(在60-80%之间变化)!

我尝试了几种不同的学习率、批次大小和训练轮数的组合,也尝试了不同数量的神经元,但它仍然表现得非常…不可预测且效果不佳。你能告诉我为什么有时候网络根本不学习吗?如果我重新运行几次,它又开始学习了。应该如何改进这个网络?

这是我的Python笔记本:https://colab.research.google.com/drive/1-50BTqnMgiz_dozv1DjXS9advD1Rxd-B?usp=sharing

提前感谢!


回答:

由于这是一个分类问题,你的神经网络的最后一层不应该使用relu激活函数。

代码片段:

class NeuralNetwork(nn.Module):    def __init__(self):        super(NeuralNetwork, self).__init__()        self.fun = nn.Sequential(            nn.Linear(3, 32),            nn.ReLU(),            nn.Linear(32, 1)        )    def forward(self, x):        return self.fun(x)

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注