如何正确使用模型解释器处理未见数据?

我使用管道训练了我的分类器:

param_tuning = {        'classifier__learning_rate': [0.01, 0.1],        'classifier__max_depth': [3, 5, 7, 10],        'classifier__min_child_weight': [1, 3, 5],        'classifier__subsample': [0.5, 0.7],        'classifier__n_estimators' : [100, 200, 500],    }cat_pipe = Pipeline(    [        ('selector', ColumnSelector(categorical_features)),        ('encoder', ce.one_hot.OneHotEncoder())    ])num_pipe = Pipeline(    [        ('selector', ColumnSelector(numeric_features)),        ('scaler', StandardScaler())    ])preprocessor = FeatureUnion(    transformer_list=[        ('cat', cat_pipe),        ('num', num_pipe)    ])xgb_pipe = Pipeline(    steps=[        ('preprocessor', preprocessor),        ('classifier', xgb.XGBClassifier())    ])grid = GridSearchCV(xgb_pipe, param_tuning, cv=5, n_jobs=-1, scoring='accuracy')xgb_model = grid.fit(X_train, y_train)

训练数据包含分类数据,因此转换后的数据形状为(x , 100 )。之后,我尝试解释模型对未见数据的预测。由于我直接将单个未见示例传递给模型,它以形状(x, 15)进行了预处理(因为单个观察没有所有分类数据的全部示例)。

eli5.show_prediction(xgb['classifier'], xgb['preprocessor'].fit_transform(df), columns = xgb['classifier'].get_booster().feature_names))

然后我得到了

ValueError: Shape of passed values is (1, 15), indices imply (1, 100).

这是因为模型是在形状为(x, 100)的整个预处理数据集上训练的,但我传递给解释器的是形状为(1,15)的单个观察。如何正确地将未见的单个观察传递给解释器?


回答:

我们绝不未见数据使用.fit_transform();正确的方法是使用已经用你的训练数据拟合好的预处理器的.transform()方法(这里是xgb['preprocessor'])。这样,我们确保(转换后的)未见数据与我们的(转换后的)训练数据具有相同的特征,因此它们与使用后者构建的模型兼容。

所以,你应该在这里将.fit_transform(df)替换为.transform(df)

eli5.show_prediction(xgb['classifier'], xgb['preprocessor'].fit_transform(df), columns = xgb['classifier'].get_booster().feature_names))

替换为:

eli5.show_prediction(xgb['classifier'], xgb['preprocessor'].transform(df), columns = xgb['classifier'].get_booster().feature_names))

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注