如何使用scikit-learn在Python中打印简单线性回归的截距和斜率?

我正在尝试使用简单线性回归(只有一个自变量)来预测汽车价格(通过机器学习)。变量是“高速公路每加仑英里数”

0      271      272      263      304      22       ..200    28201    25202    23203    27204    25Name: highway-mpg, Length: 205, dtype: int64

和“价格”:

0      13495.01      16500.02      16500.03      13950.04      17450.0        ...   200    16845.0201    19045.0202    21485.0203    22470.0204    22625.0Name: price, Length: 205, dtype: float64

使用以下代码:

from sklearn.linear_model import LinearRegressionx = df["highway-mpg"]y = df["price"]lm = LinearRegression()lm.fit([x],[y])Yhat = lm.predict([x])print(Yhat)print(lm.intercept_)print(lm.coef_)

然而,截距和斜率系数的打印命令给出了以下输出:

[[0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] ... [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.] [0. 0. 0. ... 0. 0. 0.]]

为什么它不打印截距和斜率系数呢?“Yhat”的打印命令确实正确地打印了一个数组中的预测值,但不知为何其他打印命令没有输出我想要的结果…


回答:

导致coef_intercept_看起来奇怪的原因是你的数据有205个特征和205个目标,但只有1个样本。这绝对不是你想要的!

你可能想要1个特征,205个样本,以及1个目标。要做到这一点,你需要重塑你的数据:

from sklearn.linear_model import LinearRegressionimport numpy as npmpg = np.array([27, 27, 26, 30, 22, 28, 25, 23, 27, 25]).reshape(-1, 1)price = np.array([13495.0, 16500.0, 16500.0, 13950.0, 17450.0, 16845.0, 19045.0, 21485.0, 22470.0, 22625.0])lm = LinearRegression()lm.fit(mpg, price)print(lm.intercept_)print(lm.coef_)

我在这里使用数组进行测试,但显然你应该使用数据框中的数据。

附注: 如果你省略了重塑操作,你会得到这样的错误消息:

ValueError: Expected 2D array, got 1D array instead:array=[27 27 26 30 22 28 25 23 27 25].Reshape your data either using array.reshape(-1, 1) if your data has a single feature or array.reshape(1, -1) if it contains a single sample.

^ 它告诉你该怎么做!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注