动手学习机器学习:加州房价问题

在阅读这本书的过程中,我看到了关于scikit-learn的这一部分内容:

housing["income_cat"] = pd.cut(housing["median_income"], bins=[0.,1.5,3.0,4.5,6.,np.inf], labels=[1,2,3,4,5])split =StratifiedShuffleSplit(n_splits=1, test_size=0.2, randomstate=42)for train_index, test_index in split.split(housing, housing["income_cat"])    stat_train_set = housing.loc[train_index]    stat_test_set = housing.loc[test_index]

我理解第一行代码是在housing数据框中添加一个列,并根据收入将数据分为1到5的类别。

1: 0-<1.52: 1.5-<3.03: 3.0-<4.54: 4.5-<65: >6 

我明白第二行代码返回的是一个用于分割的函数。

我不理解的是这个函数如何知道哪个索引是20%?第二个索引总是应用test_size参数的那个吗?


回答:

你只需要知道split方法会产生一个迭代器,这个迭代器会生成一个索引元组。元组的第一个元素是训练索引,第二个是测试索引。这背后没有魔法。

如果你想查看这个方法的源代码,可以在这里找到:链接。特别要注意_iter_indices方法的末尾,你会看到yield语句,它产生这个元组。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注