在scikit-learn的StratifiedShuffleSplit中’n_splits’的作用是什么?

我在阅读《Hands-on Machine Learning with Scikit-Learn and Tensorflow》这本书时,发现了以下代码:

from sklearn.model_selection import StratifiedShuffleSplitsplit = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)for train_index, test_index in split.split(housing, housing["income_cat"]):    strat_train_set = housing.loc[train_index]    strat_test_set = housing.loc[test_index]

我想了解’n_splits’参数的作用。我到处搜索了但没有找到满意的答案。提前感谢!


回答:

顾名思义,n_splits参数用于指定你希望进行多少次分割(基本上是多少个独立的分割)。

例如,设置n_splits = 3 将会使循环生成3个不同的分割(每次迭代一个),这样你可以更有效地进行验证。

设置n_splits = 1 会模拟 sklearn.model_selection.train_test_split 的行为(连同提到的stratify参数)。文档对这个函数的每个参数都有详细的解释。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注