我需要创建一个自定义变换器以输入到评分器中。
评分器将字典列表传递给我的估计器的predict或predict_proba方法,而不是DataFrame。这意味着模型必须同时处理这两种数据类型。因此,我需要提供一个自定义的ColumnSelectTransformer来替代scikit-learn自带的ColumnTransformer。
这是我为自定义变换器编写的代码,目的是删除所提供列中的空值。
simple_cols = ['BEDCERT', 'RESTOT', 'INHOSP', 'CCRC_FACIL', 'SFF', 'CHOW_LAST_12MOS', 'SPRINKLER_STATUS', 'EXP_TOTAL', 'ADJ_TOTAL']class ColumnSelectTransformer(BaseEstimator, TransformerMixin): def __init__(self, columns): self.columns = columns def fit(self, X, y=None): return self def transform(self, X): if not isinstance(X, pd.DataFrame): X = pd.DataFrame(X) X.dropna(inplace=True) return X[self.columns].values()simple_features = Pipeline([ ('cst', ColumnSelectTransformer(simple_cols)),])
然而,我无法通过以下断言测试
assert data['RESTOT'].isnull().sum() > 0assert not np.isnan(simple_features.fit_transform(data)).any()
我生成了一个类型错误
---------------------------------------------------------------------------TypeError Traceback (most recent call last)<ipython-input-44-922f08231b1f> in <module>() 1 assert not data['RESTOT'].isnull().sum() > 0----> 2 assert not np.isnan(simple_features.fit_transform(data)).any()/opt/conda/lib/python3.7/site-packages/sklearn/pipeline.py in fit_transform(self, X, y, **fit_params) 391 return Xt 392 if hasattr(last_step, 'fit_transform'):--> 393 return last_step.fit_transform(Xt, y, **fit_params) 394 else: 395 return last_step.fit(Xt, y, **fit_params).transform(Xt)/opt/conda/lib/python3.7/site-packages/sklearn/base.py in fit_transform(self, X, y, **fit_params) 551 if y is None: 552 # fit method of arity 1 (unsupervised transformation)--> 553 return self.fit(X, **fit_params).transform(X) 554 else: 555 # fit method of arity 2 (supervised transformation)<ipython-input-42-e20ea4310864> in transform(self, X) 12 X = pd.DataFrame(X) 13 X.dropna(inplace=True)---> 14 return X[self.columns].values() 15 16 simple_features = Pipeline([TypeError: 'numpy.ndarray' object is not callable
如果有人需要访问,这里是实际数据。
%%bashmkdir datawget http://dataincubator-wqu.s3.amazonaws.com/mldata/providers-train.csv -nc -P ./ml-datawget http://dataincubator-wqu.s3.amazonaws.com/mldata/providers-metadata.csv -nc -P ./ml-datadata = pd.read_csv('./ml-data/providers-train.csv', encoding='latin1')
回答:
正如日志所指出的,错误出在X[self.columns].values()
。 values
是一个numpy数组,因此你不能将其作为方法调用(在后面加括号)。你应该尝试X[self.columns].values
。