如何预测特征数量与测试集可用特征数量不匹配的情况? [重复]

我使用 pandas 的 get_dummies 函数将分类变量转换为虚拟/指示变量,这会在数据集中引入新的特征。然后我们将这个数据集拟合/训练到模型中。

由于 X_trainX_test 的维度保持相同,当我们对测试数据进行预测时,它可以很好地处理测试数据 X_test

现在假设我们有另一份 csv 文件中的测试数据(具有未知输出)。当我们使用 get_dummies 转换这组测试数据时,生成的数据集可能没有与我们训练模型时相同的特征数量。后来当我们用这个数据集使用我们的模型时,它会失败,因为测试集中的特征数量与模型的特征数量不匹配。

有什么办法可以处理这种情况吗?

代码:

import pandas as pdfrom sklearn.tree import DecisionTreeClassifierfrom sklearn.model_selection import train_test_split# 加载数据集in_file = 'train.csv'full_data = pd.read_csv(in_file)outcomes = full_data['Survived']features_raw = full_data.drop('Survived', axis = 1)features = pd.get_dummies(features_raw)features = features.fillna(0.0)X_train, X_test, y_train, y_test = train_test_split(features, outcomes, test_size=0.2, random_state=42)model = DecisionTreeClassifier(max_depth=50,min_samples_leaf=6,min_samples_split=2)model.fit(X_train,y_train)y_train_pred = model.predict(X_train)#print (X_train.shape)y_test_pred = model.predict(X_test)from sklearn.metrics import accuracy_scoretrain_accuracy = accuracy_score(y_train, y_train_pred)test_accuracy = accuracy_score(y_test, y_test_pred)print('训练准确率为', train_accuracy)print('测试准确率为', test_accuracy)# 再次进行以测试另一组数据test_data = 'test.csv'test_data1 = pd.read_csv(test_data)test_data2 = pd.get_dummies(test_data1)test_data3 = test_data2.fillna(0.0)print(test_data2.shape)print (model.predict(test_data3))

回答:

似乎之前已经有人问过类似的问题,但最有效/最简单的方法是按照 Thibault Clement这里 描述的方法进行操作

# 获取训练测试中缺失的列missing_cols = set( X_train.columns ) - set( X_test.columns )# 在测试集中添加缺失的列,默认值设为0for c in missing_cols:    X_test[c] = 0# 确保测试集中的列顺序与训练集中的列顺序相同X_test = X_test[X_train.columns]

还值得注意的是,您的模型只能使用它训练时使用的特征,因此如果 X_test 相对于 X_train 有多余的列而不是少列,那么在预测之前必须删除这些多余的列。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注