超参数调整与前馈神经网络的过拟合 – 迷你批次纪元和交叉验证

我正在研究为使用PyTorch实现的前馈神经网络(FNN)开发一种超参数调整方法。我的原始FNN模型,命名为net,使用了带有纪元的迷你批次学习方法来实现:

#Parametersbatch_size = 50 #较大的批次大小会导致过拟合num_epochs = 1000 learning_rate = 0.01 #最初是0.01 - 也称为步长 - 训练过程中权重更新的量batch_no = len(x_train) // batch_size criterion = nn.CrossEntropyLoss()  #分类模型性能,其输出是0到1之间的概率值optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate)for epoch in range(num_epochs):    if epoch % 20 == 0:        print('Epoch {}'.format(epoch+1))    x_train, y_train = shuffle(x_train, y_train)    # 迷你批次学习 - 由于批次大小 < n(批量梯度下降),但 > 1(随机梯度下降)    for i in range(batch_no):        start = i * batch_size        end = start + batch_size        x_var = Variable(torch.FloatTensor(x_train[start:end]))        y_var = Variable(torch.LongTensor(y_train[start:end]))        # 前向 + 反向 + 优化        optimizer.zero_grad()        ypred_var = net(x_var)        loss =criterion(ypred_var, y_var)        loss.backward()        optimizer.step()

最后,我在一个单独的测试集上测试我的模型。

我发现了一种使用随机搜索来调整超参数并实施K折交叉验证(RandomizedSearchCV)的方法。

我的问题是双重的(无意双关!),首先是理论性的:K折验证是否必要,或者是否能为迷你批次前馈神经网络带来任何好处?据我所见,迷你批次方法应该大致完成同样的工作,防止过拟合。

我也在这里找到了一个很好的答案,但我不确定这是否特别针对迷你批次方法进行了讨论。

其次,如果不需要K折验证,是否有其他用于PyTorch的超参数调整函数,以避免手动创建一个?


回答:

  • K折交叉验证通常在数据集非常小的时候有用。因此,如果你在像CIFAR10这样的大数据集上进行训练(60000张图像),那么你不需要K折交叉验证。
  • K折交叉验证的目的是观察模型性能(泛化能力)如何随着用于训练和测试的不同数据子集的变化而变化。当数据非常少时,这一点变得重要。然而,对于大型数据集,测试数据集上的指标结果足以测试模型的泛化能力。
  • 因此,是否需要K折交叉验证取决于你的数据集大小。这并不取决于你使用什么模型。
  • 如果你查看深度学习书籍的这一章(首次在这个链接中引用):

小批次可以提供正则化效果(Wilson和Martinez,2003年),可能是因为它们为学习过程增加了噪声。对于批次大小为1的泛化误差通常是最好的。由于梯度估计的高方差,使用如此小的批次大小进行训练可能需要较小的学习率来保持稳定性。结果,总运行时间可能很高,因为需要进行更多的步骤,既因为学习率降低,也因为需要更多的步骤才能观察到整个训练集。

  • 因此,是的,迷你批次训练在一定程度上会具有正则化效果(减少过拟合)。
  • 没有内置的超参数调整功能(至少在撰写本回答时没有),但许多开发者已经为此目的开发了工具(例如这个)。你可以通过搜索找到更多这样的工具。这个问题的答案列出了很多这样的工具。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注