PyTorch简单模型未见改善

我正在使用PyTorch构建一个简单的神经网络,以近似x = [0, 2π]上的正弦函数。这是我在不同深度学习库中使用的一个简单架构,用来测试我是否理解如何使用它。未经训练的神经网络总是产生一条水平直线,而经过训练后,则在y = 0处产生一条直线。一般来说,它总是在y =(函数的平均值)处产生一条直线。这让我认为它的前向传播部分可能出了问题,因为未经训练时边界不应该只是一条直线。以下是网络的代码:

class Net(nn.Module):    def __init__(self):      super(Net, self).__init__()      self.model = nn.Sequential(      nn.Linear(1, 20),      nn.Sigmoid(),      nn.Linear(20, 50),      nn.Sigmoid(),      nn.Linear(50, 50),      nn.Sigmoid(),      nn.Linear(50, 1)      )    def forward(self, x):        x = self.model(x)        return x

这是训练循环的代码

def train(net, trainloader, valloader, learningrate, n_epochs):    net = net.train()    loss = nn.MSELoss()    optimizer = torch.optim.SGD(net.parameters(), lr = learningrate)    for epoch in range(n_epochs):        for X, y in trainloader:            X = X.reshape(-1, 1)            y = y.view(-1, 1)            optimizer.zero_grad()            outputs = net(X)            error   = loss(outputs, y)            error.backward()            #net.parameters()  net.parameters() * learningrate            optimizer.step()        total_loss = 0        for X, y in valloader:            X = X.reshape(-1, 1).float()            y = y.view(-1, 1)            outputs = net(X)            error   = loss(outputs, y)            total_loss += error.data        print('Val loss for epoch', epoch, 'is', total_loss / len(valloader) )

它的调用方式如下:

net = Net()losslist = train(net, trainloader, valloader, .0001, n_epochs = 4)

其中trainloader和valloader分别是训练和验证数据加载器。谁能帮我看看这是哪里出了问题?我知道这不是学习率的问题,因为这是我在其他框架中使用的相同学习率,我知道这不是因为我使用了SGD或Sigmoid激活函数,尽管我怀疑错误可能出在激活函数的某个地方。

有谁知道如何解决这个问题?谢谢。


回答:

经过一段时间调整一些超参数、修改网络和更换优化器(遵循这个优秀的配方)后,我最终将代码行optimizer = torch.optim.SGD(net.parameters(), lr = learningrate)改为optimizer = torch.optim.Adam(net.parameters())(使用了默认的优化器参数),运行了100个周期,批量大小为1。

以下是使用的代码(仅在CPU上测试):

使用Adam优化器的结果:enter image description here

使用SGD优化器的结果:enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注