为什么线性回归模型的目标和预测可视化不准确?

使用多元线性回归模型来估算吸烟者的医疗费用。我使用了’age’(年龄)、’bmi’(体重指数)、’children’(子女数量)这三个特征来估算“charges”(费用)。以下是我的代码:

import pandas as pdimport numpy as npimport plotly.express as pximport matplotlibimport matplotlib.pyplot as pltimport seaborn as sns%matplotlib inlinefrom sklearn.linear_model import LinearRegressionfrom sklearn.metrics import mean_squared_error as rmse

从GitHub仓库读取数据

smoker_df = pd.read_csv('https://raw.githubusercontent.com/stedy/Machine-Learning-with-R-datasets/master/insurance.csv')

创建输入和目标

inputs  = smoker_df[['age', 'bmi', 'children']]targets = smoker_df['charges']

创建并训练模型

model6 = LinearRegression().fit(inputs, targets)

生成预测

predictions = model6.predict(inputs)

计算损失以评估模型

loss = rmse(targets, predictions)print('Loss:', loss)

预测和目标的可视化:

fig, ax = plt.subplots(figsize=(7, 3.5))ax.plot(predictions, targets, color='k', label='Regression model')ax.set_ylabel('predictions', fontsize=14)ax.set_xlabel('targets', fontsize=14)ax.legend(facecolor='white', fontsize=11)

这不是一个好的可视化效果。我该如何改进它,以便获得一些见解?以及如何可视化超过三个特征作为输入和一个特征作为目标?Output

数据来源


回答:

您可以使用散点图来可视化预测与观察值的对比:

fig, ax = plt.subplots(figsize=(7, 3.5))ax.scatter(predictions, targets)ax.set_xlabel('prediction', fontsize=14)ax.set_ylabel('charges', fontsize=14)ax.legend(facecolor='white', fontsize=11)

enter image description here

您可以看到一些预测值存在偏差,这是因为您没有包括其他变量:

import seaborn as snssns.scatterplot(data=smoker_df,x = "age", y = "charges",hue="smoker")

enter image description here

您还可以检查其他特征与目标的相关性:

fig, ax = plt.subplots(1,3,figsize=(15, 5))for i,x in enumerate(inputs.columns):    ax[i].scatter(inputs[[x]], targets, label=x)    ax[i].set_xlabel(x, fontsize=14)    ax[i].set_ylabel('charges', fontsize=14)    ax[i].legend(facecolor='white', fontsize=11)plt.tight_layout()

enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注