我在阅读《Hands-on Machine Learning with Scikit-Learn and Tensorflow》这本书时,发现了以下代码:
from sklearn.model_selection import StratifiedShuffleSplitsplit = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)for train_index, test_index in split.split(housing, housing["income_cat"]): strat_train_set = housing.loc[train_index] strat_test_set = housing.loc[test_index]
我想了解’n_splits’参数的作用。我到处搜索了但没有找到满意的答案。提前感谢!
回答:
顾名思义,n_splits参数用于指定你希望进行多少次分割(基本上是多少个独立的分割)。
例如,设置n_splits = 3 将会使循环生成3个不同的分割(每次迭代一个),这样你可以更有效地进行验证。
设置n_splits = 1 会模拟 sklearn.model_selection.train_test_split 的行为(连同提到的stratify参数)。文档对这个函数的每个参数都有详细的解释。