我在交叉验证中需要多个模型吗?

我看到过不同的交叉验证实现方法。我目前使用PyTorch来训练神经网络。我的当前布局如下:我有6个独立的数据集。其中5个用于交叉验证。

Network_1 trains on Datasets: 1,2,3,4 computes loss on 5Network_2 trains on Datasets: 1,2,3,5 computes loss on 4Network_3 trains on Datasets: 1,2,4,5 computes loss on 3Network_4 trains on Datasets: 1,3,4,5 computes loss on 2Network_5 trains on Datasets: 2,3,4,5 computes loss on 1

然后进入第二个周期,我再次执行相同的操作:

Network_1 trains on Datasets: 1,2,3,4 computes loss on 5Network_2 trains on Datasets: 1,2,3,5 computes loss on 4Network_3 trains on Datasets: 1,2,4,5 computes loss on 3Network_4 trains on Datasets: 1,3,4,5 computes loss on 2Network_5 trains on Datasets: 2,3,4,5 computes loss on 1

对于测试数据集6,我应该合并所有5个网络的预测,并取预测的平均分数(我还需要对预测矩阵进行平均)。

我是否正确理解了交叉验证?这是它应该的工作方式吗?这会正常工作吗?我已经努力避免使用已经训练过的数据进行测试。我仍然不…

非常感谢您的帮助 🙂


回答:

您当然可以将交叉验证应用于神经网络,但由于神经网络是计算量大的模型,通常不会这样做。为了减少方差,神经网络通常会应用其他技术,如早期停止或丢弃(dropout)。

尽管如此,我不确定您是否以正确的方式应用了它。您应该在所有周期内进行训练,以便:

Network_1 trains on Datasets: 1,2,3,4 up to the end of training. Then computes loss on 5Network_2 trains on Datasets: 1,2,3,5 up to the end of training. Then computes loss on 4Network_3 trains on Datasets: 1,2,4,5 up to the end of training. Then computes loss on 3Network_4 trains on Datasets: 1,3,4,5 up to the end of training. Then computes loss on 2Network_5 trains on Datasets: 2,3,4,5 up to the end of training. Then computes loss on 1

一旦每个网络都训练到结束(即跨越所有周期),并在留出的数据集上进行验证(称为验证数据集),您就可以平均您获得的分数。
这个分数(以及交叉验证的真正意义)应该能给您一个公正的模型评估,当您在测试集(从一开始就未用于训练的数据集)上测试模型时,分数不应下降。

交叉验证通常与某种形式的网格搜索结合使用,以产生不同模型的无偏评估形式。所以,如果您想比较NetworkANetworkB,它们在某些参数上有所不同,您可以对NetworkA进行交叉验证,对NetworkB进行交叉验证,然后选择交叉验证得分最高的那个作为最终模型。

最后一步,一旦您决定了哪个是最佳模型,您通常会重新训练您的模型,使用您在训练集中所有的数据(即在您的例子中是数据集1,2,3,4,5),然后在测试集(数据集6)上测试这个模型。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注