为什么在使用sklearn时,10折交叉验证的准确率比90-10的训练测试分割更差?

任务是通过神经网络进行二元分类。数据以字典形式存在,字典包含每个条目的复合名称(作为键)和标签(0或1,作为向量值的第三个元素)。第一个和第二个元素是复合名称的两部分,后来用于提取相应的特征。

在这两种情况下,字典被转换成两个数组,以便对多数类(在数据中占66%)进行平衡欠采样:

data_for_sampling = np.asarray([key for key in list(data.keys())])labels_for_sampling = [element[2] for element in list(data.values())]sampler = RandomUnderSampler(sampling_strategy = 'majority')data_sampled, label_sampled = sampler.fit_resample(data_for_sampling.reshape(-1, 1), labels_for_sampling)

然后,使用Kfold方法将重新采样的名称和标签数组用于创建训练和测试集:

kfolder = KFold(n_splits = 10, shuffle = True)kfolder.get_n_splits(data_sampled)for train_index, test_index in kfolder.split(data_sampled):        data_train, data_test = data_sampled[train_index], data_sampled[test_index]

或者使用train_test_split方法:

data_train, data_test, label_train, label_test = train_test_split(data_sampled, label_sampled, test_size = 0.1, shuffle = True)

最后,使用data_train和data_test中的名称从原始字典中重新提取相关条目(通过键),然后用于收集这些条目的特征。就我而言,10折集的单一分割应提供与90-10的train_test_split类似的训练-测试数据分布,并且在训练过程中似乎确实如此,在仅运行一个epoch后,两个训练集的结果均为约0.82的准确率,分别使用model.fit()运行。然而,当使用model.evaluate()在上述epoch后对相应模型在测试集上进行评估时,来自train_test_split的集合给出的分数为约0.86,而来自Kfold的集合为约0.72。我进行了多次测试以查看是否只是一个不受限制的坏随机种子,但结果保持不变。这些集合还具有正确的平衡标签分布和看似良好的洗牌条目。


回答:

事实证明,问题源自多种因素的组合:

虽然train_test_split()方法中的shuffle = True首先正确地洗牌提供的数据,然后将其分割成所需的部分,但在Kfold方法中的shuffle = True只会导致随机构建的折叠,然而折叠内的数据保持有序

这是文档中指出的,感谢此帖:https://github.com/scikit-learn/scikit-learn/issues/16068

现在,在学习过程中,我的自定义训练函数再次对训练数据应用洗牌,以确保万无一失,但它不会洗牌测试数据。此外,如果没有给出参数,model.evaluate()默认使用batch_size = 32,这与有序的测试数据结合,导致验证准确率的差异。测试数据确实存在缺陷,因为它包含大量“难以预测”的条目,这些条目由于排序而聚集在一起,似乎拉低了结果中的平均准确率。正如TC Arlen所指出的,完成所有N折的运行可能最终确实提供了更精确的估计,但我期望在仅一个折叠后得到更接近的结果,这导致了这一问题的发现。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注