如何在k折交叉验证后绘制每个折叠的数据和模型拟合?

我正在尝试根据一个特征预测一个标签变量。这两者似乎具有很高的线性相关性。我选择了线性回归模型来描述数据。我的代码输出显示了训练和测试数据的R2分数。我的模型表现很好,但测试样本中的一个折叠除外,其R2为负。我希望绘制每个折叠的数据和模型的拟合情况,以便了解哪里出了问题。然而,从Python编码的角度来看,我无法弄清楚如何做到这一点。

有谁能帮忙吗?

Test_scores = list()Train_scores =list()n_splits = 5kfold = KFold(n_splits=n_splits              , shuffle=False)for train_ix, test_ix in kfold.split(Feature_X):    Train_Feature_X, Test_Feature_X=Feature_X[train_ix], Feature_X[test_ix]    Train_label_X, Test_label_X= label_X[train_ix],label_X[test_ix]    model = LinearRegression()    model.fit(Train_Feature_X, Train_label_X)    pred_label_train = model.predict(Train_Feature_X)    acc_train = r2_score(Train_label_X, pred_label_train)    pred_label_test = model.predict(Test_Feature_X)    acc_test = r2_score(Test_label_X, pred_label_test)    Test_scores.append(acc_test)    Train_scores.append(acc_train)    print('> ', 'Train:'+ str(acc_train), "Test:"+ str(acc_test))Test_mean, Test_std = np.mean(Test_scores), np.std(Test_scores)Train_mean, Train_std = np.mean(Train_scores), np.std(Train_scores)print('Mean of test: %.3f, Standard Deviation: %.3f' % (Test_mean, Test_std))print('Mean of train: %.3f, Standard Deviation: %.3f' % (Train_mean, Train_std))

代码输出:

>  Train:0.9948113361306588 Test:0.9715872368615199>  Train:0.9905854864161807 Test:0.9917503220348162>  Train:0.9888929852977923 Test:-4.996610921978263>  Train:0.990942242777374 Test:0.9960355777732937>  Train:0.9925744355834707 Test:0.9458246438971184Mean of test: -0.218, Standard Deviation: 2.389Mean of train: 0.992, Standard Deviation: 0.002

enter image description here


回答:

您可以直接在循环中添加绘图代码。

在每次迭代中,您可以访问训练-测试折叠和预测结果,因此在打印值之前,即print('> ', 'Train:'+ str(acc_train), "Test:"+ str(acc_test)),您可以做类似以下的事情:

fig, ax = plt.subplots(nrows=1, ncols=5)curr_split = 1for ...    plt.subplot(1, 5, curr_split)    plt.plot(x, y)    curr_split += 1plt.show()

这将绘制5个子图,每个子图代表一个折叠。

请注意,这只是您应该做的事情的一般大纲,请参考以下链接中的文档 plt.subplots()

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注