在试图找出train_test_split
和StratifiedShuffleSplit
之间的区别时,我发现了以下声明。
当
stratify
不为None时,train_test_split
内部使用StratifiedShuffleSplit
,
我只是在想,为什么当我们可以使用train_test_split
中的stratify
参数时,还要使用sklearn.model_selection
中的StratifiedShuffleSplit
。
回答:
主要是为了可重用性。相比于复制已经为StratifiedShuffleSplit
实现的代码,train_test_split
只是调用那个类。出于同样的原因,当stratify=False
时,它使用model_selection.ShuffleSplit
类(参见源代码)。
请注意,复制代码被认为是一种不好的做法,因为它被认为会增加维护成本,而且由于对代码复制的更改不一致,可能会导致意外的行为。如果你想了解更多,这里有一篇参考文献。
此外,尽管它们执行相同的任务,但它们并不能在所有情况下都通用。例如,train_test_split
不能在使用sklearn.model_selection.RandomizedSearchCV
或sklearn.model_selection.GridSearchCV
进行随机或网格搜索时使用。而StratifiedShuffleSplit
可以。原因是前者不是“一个生成(train, test)分割作为索引数组的可迭代对象”。而后者有一个方法split
,它生成(train, test)分割作为索引数组。更多信息请参见这里(查看参数cv)。