模型置信度过低

我想知道为什么模型的得分非常低,只有0.13。我已经确保数据是干净的、经过缩放的,并且各特征之间有很高的相关性,但使用线性回归的模型得分仍然很低,这是为什么呢?如何解决这个问题?这是我的代码

import numpy as np from sklearn.linear_model import LinearRegressionfrom sklearn.model_selection import train_test_splitimport pandas as pdimport seaborn as snsimport matplotlib.pyplot as pltfrom sklearn import preprocessingpath = r"D:\python projects\avocado.csv"df = pd.read_csv(path)df = df.reset_index(drop=True)df.set_index('Date', inplace=True)df = df.drop(['Unnamed: 0','year','type','region','AveragePrice'],1)df.rename(columns={'4046':'Small HASS sold',                          '4225':'Large HASS sold',                          '4770':'XLarge HASS sold'},                  inplace=True)print(df.head)sns.heatmap(df.corr())sns.pairplot(df)df.plot()_=plt.xticks(rotation=20)forecast_line = 35df['target'] = df['Total Volume'].shift(-forecast_line)X = np.array(df.drop(['target'], 1))X = preprocessing.scale(X)X_lately = X[-forecast_line:]X = X[:-forecast_line]df.dropna(inplace=True)y = np.array(df['target'])X_train, X_test, y_train, y_test = train_test_split(X, y,test_size=0.2)lr = LinearRegression()lr.fit(X_train,y_train)confidence = lr.score(X_test,y_test)print(confidence)

这是我使用的数据集的链接

https://www.kaggle.com/neuromusic/avocado-prices


回答:

所以你使用的评分函数是:

返回预测的决定系数R^2。

决定系数R^2定义为(1 – u/v),其中u是残差平方和((y_true – y_pred) ** 2).sum(),v是总平方和((y_true – y_true.mean()) ** 2).sum()。最佳可能得分为1.0,并且可以是负值(因为模型可能任意差)。一个总是预测y的期望值的常数模型,不考虑输入特征,会得到0.0的R^2得分。

所以正如你所意识到,你已经超过了常数预测。

我的建议是尝试绘制你的数据,看看应该使用哪种回归。这里你可以看到可用的线性回归类型概览:https://scikit-learn.org/stable/modules/linear_model.html

如果你的数据呈现出逻辑曲线,即你的点要么接近0要么接近1,中间点不多,那么使用逻辑回归是有意义的。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注