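I'm using the StratifiedShuffleSplit cross-validator to split the data while predicting house prices on the Boston housing dataset. When I run the sample code below: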
```python
from sklearn.model_selection import StratifiedShuffleSplit

def fit_model_S(labels, features, step, clf, parameters):
    cv = StratifiedShuffleSplit(n_splits=2, test_size=0.10, random_state=42)
    print(cv)
    for train_index, test_index in cv.split(features, labels):
        labels_train, labels_test = labels[train_index], labels[test_index]
        features_train, features_test = features[train_index], features[test_index]
```
I get the error below. The code runs fine with ShuffleSplit. Does this mean StratifiedShuffleSplit can't be used with numeric labels?
```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-141-b290147edcbf> in <module>()
     33 dt_steps = [('decision', clf)]
     34 
---> 35 fit_model_S(labels, features,dt_steps,clf,parameters4)
     36 
     37 

<ipython-input-141-b290147edcbf> in fit_model_S(labels, features, step, clf, parameters)
      8     cv = StratifiedShuffleSplit(n_splits=2,test_size=0.10, random_state = 42)
      9     print (cv)
---> 10     for train_index, test_index in cv.split(features,labels):
     11 
     12         labels_train, labels_test = labels[train_index], labels[test_index]

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\model_selection\_split.py in split(self, X, y, groups)
   1194         """
   1195         X, y, groups = indexable(X, y, groups)
-> 1196         for train, test in self._iter_indices(X, y, groups):
   1197             yield train, test
   1198 

C:\ProgramData\Anaconda3\lib\site-packages\sklearn\model_selection\_split.py in _iter_indices(self, X, y, groups)
   1535         class_counts = np.bincount(y_indices)
   1536         if np.min(class_counts) < 2:
-> 1537             raise ValueError("The least populated class in y has only 1"
   1538                              " member, which is too few. The minimum"
   1539                              " number of groups for any class cannot"

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
```
A sample of the dataset:
```
      RM  LSTAT  PTRATIO      MEDV
0  6.575   4.98     15.3  504000.0
1  6.421   9.14     17.8  453600.0
2  7.185   4.03     17.8  728700.0
3  6.998   2.94     18.7  701400.0
4  7.147   5.33     18.7  760200.0
```
Here, MEDV is the label.
Answer:
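The Boston housing data is a dataset for a regression problem. You are using StratifiedShuffleSplit to split it into training and test sets. As the documentation says, StratifiedShuffleSplit is:

> This cross-validation object is a merge of StratifiedKFold and ShuffleSplit, which returns stratified randomized folds. The folds are made by preserving the percentage of samples for each class.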
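To see what "preserving the percentage of samples for each class" means in practice, here is a minimal sketch with made-up discrete labels (the 80/20 class mix and the dummy X are hypothetical, not from the question). Only y drives the stratification:

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Hypothetical discrete labels: 80 samples of class 0 and 20 of class 1 (an 80/20 mix).
y = np.array([0] * 80 + [1] * 20)
X = np.zeros((len(y), 1))  # placeholder features; they do not affect how the split is made

cv = StratifiedShuffleSplit(n_splits=2, test_size=0.10, random_state=42)
for train_index, test_index in cv.split(X, y):
    # Each 10-sample test fold keeps the 80/20 ratio: 8 of class 0 and 2 of class 1.
    print(np.bincount(y[test_index]))  # prints [8 2] for each split
```

This works because every class has plenty of members to distribute between the training and test sets.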
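Note the last part: "preserving the percentage of samples for each class". So StratifiedShuffleSplit treats the y values as class labels.

That cannot work here, because your y is a regression target (continuous numeric data). Every distinct price becomes its own "class", many prices occur only once, and StratifiedShuffleSplit requires every class to have at least 2 members. That is exactly what the error says: "The least populated class in y has only 1 member".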
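A minimal sketch reproducing the problem, assuming y is just the five MEDV values from the sample rows in the question (the dummy X is a placeholder, since only y drives stratification):

```python
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# MEDV values from the sample rows: a continuous target in which every price is distinct.
y = np.array([504000.0, 453600.0, 728700.0, 701400.0, 760200.0])
X = np.zeros((len(y), 1))

# StratifiedShuffleSplit treats each distinct y value as a separate class.
classes, counts = np.unique(y, return_counts=True)
print(len(classes), counts.min())  # prints: 5 1  (5 "classes", each with a single member)

cv = StratifiedShuffleSplit(n_splits=2, test_size=0.10, random_state=42)
try:
    next(cv.split(X, y))  # split() is a generator, so the error appears on the first iteration
except ValueError as err:
    print("ValueError:", err)
```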
Use ShuffleSplit or train_test_split to split your data instead. For more details on cross-validation, see: http://scikit-learn.org/stable/modules/cross_validation.html#cross-validation
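A sketch of the question's loop with ShuffleSplit swapped in, plus the single-split train_test_split equivalent. The random features and labels are hypothetical stand-ins for the real data; the split parameters are the ones from the question:

```python
import numpy as np
from sklearn.model_selection import ShuffleSplit, train_test_split

# Hypothetical stand-ins: 100 samples, 3 features (like RM, LSTAT, PTRATIO)
rng = np.random.RandomState(0)
features = rng.rand(100, 3)
labels = rng.rand(100) * 1e6  # continuous target, like MEDV

# Same parameters as in the question; ShuffleSplit does not stratify, so continuous labels are fine.
cv = ShuffleSplit(n_splits=2, test_size=0.10, random_state=42)
for train_index, test_index in cv.split(features, labels):
    labels_train, labels_test = labels[train_index], labels[test_index]
    features_train, features_test = features[train_index], features[test_index]
    print(features_train.shape, features_test.shape)  # prints (90, 3) (10, 3)

# Or a single random split with train_test_split
features_train, features_test, labels_train, labels_test = train_test_split(
    features, labels, test_size=0.10, random_state=42
)
```

If features and labels are a pandas DataFrame and Series rather than NumPy arrays, index them with `.iloc[train_index]` instead, because plain `[]` on a DataFrame selects columns, not rows.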