提高或稳定KNN模型在IRIS数据集上的准确率得分的关键因素

提高或稳定这个基本KNN模型在IRIS数据集上的准确率得分不要有显著变化),可能有哪些关键因素?

尝试

from sklearn import neighbors, datasets, preprocessingfrom sklearn.model_selection import train_test_splitfrom sklearn.metrics import accuracy_scorefrom sklearn.metrics import classification_reportfrom sklearn.metrics import confusion_matrixiris = datasets.load_iris() X, y = iris.data[:, :], iris.targetXtrain, Xtest, y_train, y_test = train_test_split(X, y)scaler = preprocessing.StandardScaler().fit(Xtrain)Xtrain = scaler.transform(Xtrain)Xtest = scaler.transform(Xtest)knn = neighbors.KNeighborsClassifier(n_neighbors=4)knn.fit(Xtrain, y_train)y_pred = knn.predict(Xtest)print(accuracy_score(y_test, y_pred))print(classification_report(y_test, y_pred))print(confusion_matrix(y_test, y_pred))

样本准确率得分

0.97368421052631580.94736842105263151.00.9210526315789473

分类报告

              precision    recall  f1-score   support           0       1.00      1.00      1.00        12           1       0.79      1.00      0.88        11           2       1.00      0.80      0.89        15    accuracy                           0.92        38   macro avg       0.93      0.93      0.92        38weighted avg       0.94      0.92      0.92        38

样本混淆矩阵

[[12  0  0] [ 0 11  0] [ 0  3 12]]

回答:

我建议调整k-NN的k值。由于iris数据集较小且平衡性好,我会采取以下步骤:

对于范围在[2到10]内的每个k值(例如)  执行n次k折交叉验证(例如n=20,k=4)    存储准确率值(或其他任何指标)

根据平均值和方差绘制得分图,并选择最佳的k值。交叉验证的主要目标是估计测试误差,并据此选择最终模型。会有一些方差,但应该小于0.03左右。这取决于数据集和您选择的折数。一个好的方法是,对于每个k值,绘制所有20×4个准确率值的箱线图。选择下四分位数与上四分位数相交的k值,或者简单来说,如果准确率(或其他指标值)变化不大,则选择该k值。

一旦基于此选择了k值,目标是使用这个值在整个训练数据集上构建最终模型。接下来,可以用它来预测新数据。

另一方面,对于较大的数据集。创建一个单独的测试分区(如您在这里所做的那样),然后仅在训练集上调整k值(使用交叉验证,忽略测试集)。在选择合适的k值后,使用仅训练集进行训练。接下来,使用测试集报告最终值。永远不要基于测试集做出任何决策。

另一种方法是训练、验证、测试分区。使用训练集训练,并使用不同k值训练模型,然后使用验证分区进行预测并列出得分。基于此验证分区选择最佳得分。接下来,使用训练集或训练+验证集训练最终模型,使用基于验证集选择的k值。最后,取出测试集并报告最终得分。同样,永远不要在其他地方使用测试集。

这些是一般方法,适用于任何机器学习或统计学习方法。

在执行分区(训练、测试或交叉验证)时,请注意使用分层抽样,以便在每个分区中类别比例保持不变。

阅读更多关于交叉验证的内容。在scikitlearn中很容易实现。如果使用R,您可以使用caret

要记住的主要事情是,目标是训练一个在新数据上具有泛化能力的函数,或者在新数据上表现良好,而不是仅仅在现有数据上表现良好。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注