使用PyTorch训练的神经网络对每个输入输出均值

我正在使用PyTorch来训练我的神经网络识别来自MNIST数据库的数字。

import torchimport torchvision

我想实现一个非常简单的设计,类似于3Blue1Brown关于神经网络的视频系列中展示的。特别是以下设计达到了1.6%的错误率。

class Net(torch.nn.Module):    def __init__(self):        super(Net, self).__init__()        self.layer1 = torch.nn.Linear(784, 800)        self.layer2 = torch.nn.Linear(800, 10)    def forward(self, x):        x = torch.sigmoid(self.layer1(x))        x = torch.sigmoid(self.layer2(x))        return x

数据是通过torchvision收集并组织成每个包含32张图片的小批次的。

batch_size = 32training_set = torchvision.datasets.MNIST("./", download=True, transform=torchvision.transforms.ToTensor())training_loader = torch.utils.data.DataLoader(training_set, batch_size=32)

我使用均方误差作为损失函数,并使用学习率为0.001的随机梯度下降作为我的优化算法。

net = Net()loss_function = torch.nn.MSELoss()optimizer = torch.optim.SGD(net.parameters(), lr=0.001)

最后,神经网络通过以下代码进行训练并保存:

for images, labels in training_loader:    optimizer.zero_grad()    for i in range(batch_size):        output = net(torch.flatten(images[i]))        desired_output = torch.tensor([float(j == labels[i]) for j in range(10)])        loss = loss_function(output, desired_output)        loss.backward()    optimizer.step()torch.save(net.state_dict(), "./trained_net.pth")

然而,以下是一些测试图片的输出结果:

tensor([0.0978, 0.1225, 0.1018, 0.0961, 0.1022, 0.0885, 0.1007, 0.1077, 0.0994,        0.1081], grad_fn=<SigmoidBackward>)tensor([0.0978, 0.1180, 0.1001, 0.0929, 0.1006, 0.0893, 0.1010, 0.1051, 0.0978,        0.1067], grad_fn=<SigmoidBackward>)tensor([0.0981, 0.1227, 0.1018, 0.0970, 0.0979, 0.0908, 0.1001, 0.1092, 0.1011,        0.1088], grad_fn=<SigmoidBackward>)tensor([0.1061, 0.1149, 0.1037, 0.1001, 0.0957, 0.0919, 0.1044, 0.1022, 0.0997,        0.1052], grad_fn=<SigmoidBackward>)tensor([0.0996, 0.1137, 0.1005, 0.0947, 0.0977, 0.0916, 0.1048, 0.1109, 0.1013,        0.1085], grad_fn=<SigmoidBackward>)tensor([0.1008, 0.1154, 0.0986, 0.0996, 0.1031, 0.0952, 0.0995, 0.1063, 0.0982,        0.1094], grad_fn=<SigmoidBackward>)tensor([0.0972, 0.1235, 0.1013, 0.0984, 0.0974, 0.0907, 0.1032, 0.1075, 0.1001,        0.1080], grad_fn=<SigmoidBackward>)tensor([0.0929, 0.1258, 0.1016, 0.0978, 0.1006, 0.0889, 0.1001, 0.1068, 0.0986,        0.1024], grad_fn=<SigmoidBackward>)tensor([0.0982, 0.1207, 0.1040, 0.0990, 0.0999, 0.0910, 0.0980, 0.1051, 0.1039,        0.1078], grad_fn=<SigmoidBackward>)

如您所见,网络似乎接近于对每个输入的答案都是:

[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1]

这个神经网络的表现还不如随机猜测。我在设计或代码中哪里出错了?


回答:

以下是一些对您有用的建议:

  • 乍一看,您的模型似乎没有在学习,因为您的预测与随机猜测一样好。首先,您应该监控您的损失,您这里只有一个epoch。至少您可以对未见过的数据进行评估:

    validation_set = torchvision.datasets.MNIST('./',     download=True, train=False, transform=T.ToTensor())validation_loader = DataLoader(validation_set, batch_size=32)
  • 您使用均方误差(L2范数)来训练分类任务,这不是这种任务的合适工具。您可以改用负对数似然。PyTorch提供了nn.CrossEntropyLoss,它包含了log-softmax和负对数似然损失在一个模块中。可以通过添加以下代码来实现这个更改:

    loss_function = nn.CrossEntropyLoss()

    并在应用loss_function时使用正确的目标形状(*见下文*)。由于损失函数会应用log-softmax,您不应该在模型输出上使用激活函数。

  • 您使用sigmoid作为激活函数,中间的非线性激活函数使用ReLU (见相关帖子) 会更好。sigmoid更适合二分类任务。同样,由于我们使用nn.CrossEntropyLoss,我们必须移除layer2后的激活函数。

    class Net(torch.nn.Module):    def __init__(self):        super(Net, self).__init__()        self.flatten = nn.Flatten()        self.layer1 = torch.nn.Linear(784, 800)        self.layer2 = torch.nn.Linear(800, 10)    def forward(self, x):        x = self.flatten(x)        x = torch.relu(self.layer1(x))        x = self.layer2(x)        return x
  • 一个不太关键的点是,您可以对整个批次进行推断,而不是逐个元素地遍历批次。一个典型的单epoch训练循环看起来像这样:

    for images, labels in training_loader:    optimizer.zero_grad()    output = net(images)    loss = loss_function(output, labels)    loss.backward()    optimizer.step()

通过这些修改,您可以期待在单个epoch后达到大约80%的验证准确率。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注