在阅读这本书的过程中,我看到了关于scikit-learn的这一部分内容:
housing["income_cat"] = pd.cut(housing["median_income"], bins=[0.,1.5,3.0,4.5,6.,np.inf], labels=[1,2,3,4,5])split =StratifiedShuffleSplit(n_splits=1, test_size=0.2, randomstate=42)for train_index, test_index in split.split(housing, housing["income_cat"]) stat_train_set = housing.loc[train_index] stat_test_set = housing.loc[test_index]
我理解第一行代码是在housing数据框中添加一个列,并根据收入将数据分为1到5的类别。
1: 0-<1.52: 1.5-<3.03: 3.0-<4.54: 4.5-<65: >6
我明白第二行代码返回的是一个用于分割的函数。
我不理解的是这个函数如何知道哪个索引是20%?第二个索引总是应用test_size参数的那个吗?
回答:
你只需要知道split方法会产生一个迭代器,这个迭代器会生成一个索引元组。元组的第一个元素是训练索引,第二个是测试索引。这背后没有魔法。
如果你想查看这个方法的源代码,可以在这里找到:链接。特别要注意_iter_indices
方法的末尾,你会看到yield
语句,它产生这个元组。