使用train_test_split后分类器准确率达到100%

我正在处理蘑菇分类数据集(数据集在这里可以找到:https://www.kaggle.com/uciml/mushroom-classification)。

我试图将数据分成训练集和测试集以用于我的模型,然而当我使用train_test_split方法时,我的模型总是能达到100%的准确率。但当我手动分割数据时,情况并非如此。

x = data.copy()y = x['class']del x['class']x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33)model = xgb.XGBClassifier()model.fit(x_train, y_train)predictions = model.predict(x_test)print(confusion_matrix(y_test, predictions))print(accuracy_score(y_test, predictions))

这会产生以下结果:

[[1299    0] [   0 1382]]1.0

如果我手动分割数据,我会得到一个更合理的结果。

x = data.copy()y = x['class']del x['class']x_train = x[0:5443]x_test = x[5444:]y_train = y[0:5443]y_test = y[5444:]model = xgb.XGBClassifier()model.fit(x_train, y_train)predictions = model.predict(x_test)print(confusion_matrix(y_test, predictions))print(accuracy_score(y_test, predictions))

结果:

[[2007    0] [ 336  337]]0.8746268656716418

是什么导致了这种行为?

编辑:根据请求,我包括了切片的形状。

train_test_split:

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.33)print(x_train.shape)print(y_train.shape)print(x_test.shape)print(y_test.shape)

结果:

(5443, 64)(5443,)(2681, 64)(2681,)

手动分割:

x_train = x[0:5443]x_test = x[5444:]y_train = y[0:5443]y_test = y[5444:]print(x_train.shape)print(y_train.shape)print(x_test.shape)print(y_test.shape)

结果:

(5443, 64)(5443,)(2680, 64)(2680,)

我尝试定义自己的分割函数,结果也导致分类器准确率达到100%。

以下是分割函数的代码

def split_data(dataFrame, testRatio):  dataCopy = dataFrame.copy()  testCount = int(len(dataFrame)*testRatio)  dataCopy = dataCopy.sample(frac = 1)  y = dataCopy['class']  del dataCopy['class']  return dataCopy[testCount:], dataCopy[0:testCount], y[testCount:], y[0:testCount]

回答:

你在使用train_test_split时运气不错。你手动进行的分割可能包含了更多的未见数据,这比train_test_split(它内部会打乱数据进行分割)的验证效果更好。

为了进行更好的验证,请使用K折交叉验证,这样可以验证模型在数据的不同部分作为测试集和其余部分作为训练集时的准确性。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注