尝试理解机器学习示例脚本

我正在尝试研究一个关于机器学习的示例脚本:线性模型系数解释中的常见陷阱,但我对其中的一些步骤感到困惑。脚本的开头看起来是这样的:

import numpy as npimport scipy as spimport pandas as pdimport matplotlib.pyplot as pltimport seaborn as snsfrom sklearn.datasets import fetch_openmlsurvey = fetch_openml(data_id=534, as_frame=True)# 我们识别特征 `X` 和目标 `y`:列 WAGE 是我们的# 目标变量(即我们想要预测的变量)。X = survey.data[survey.feature_names]X.describe(include="all")X.head()# 我们预测的目标是工资。y = survey.target.values.ravel()survey.target.head()from sklearn.model_selection import train_test_splitX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)train_dataset = X_train.copy()train_dataset.insert(0, "WAGE", y_train)_ = sns.pairplot(train_dataset, kind='reg', diag_kind='kde')

我的问题出在以下几行:

y = survey.target.values.ravel()survey.target.head()

如果我们在这几行之后立即检查 survey.target.head(),输出是

Out[36]: 0    5.101    4.952    6.673    4.004    7.50Name: WAGE, dtype: float64

模型如何知道 WAGE 是目标变量?不需要明确声明吗?


回答:

survey.target.values.ravel() 旨在将数组展平,但在本例中这是不必要的。survey.target 是一个 pandas Series(即一列数据框),而 survey.target.values 是一个 numpy 数组。由于 survey.target 只有一个列,你可以使用两者进行训练/测试拆分。

type(survey.target)pandas.core.series.Seriestype(survey.target.values)numpy.ndarray

如果我们只使用 survey.target,你可以看到回归仍然有效:

y = survey.targetX_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)train_dataset = X_train.copy()train_dataset.insert(0, "WAGE", y_train)sns.pairplot(train_dataset, kind='reg', diag_kind='kde')

enter image description here

如果你有另一个数据集,例如 iris,我想对花瓣宽度进行回归。你可以使用方括号 [] 来调用数据框的列:

from sklearn.datasets import load_irisfrom sklearn.linear_model import LinearRegressiondat = load_iris(as_frame=True).frameX = dat[['sepal length (cm)','sepal width (cm)','petal length (cm)']]y = dat[['petal width (cm)']]X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)LR = LinearRegression()LR.fit(X_train,y_train)plt.scatter(x=y_test,y=LR.predict(X_test))

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注