我正在编写一个类,用于在不同的缩放器(scaler)之间切换。下面的代码能够运行,但并没有真正切换缩放器:
from sklearn.preprocessing import StandardScaler, MinMaxScaler


# Question's attempt: choose the scaler at construction time through multiple
# inheritance. As the two identical outputs below show, it never switches.
class CustomTransformer(StandardScaler, MinMaxScaler):
    # StandardScaler is listed first among the bases, so Python's method
    # resolution order looks up fit/transform on StandardScaler before
    # MinMaxScaler. Every instance therefore runs StandardScaler's methods,
    # whatever value `which` has.
    def __init__(self, which,with_std=True,with_mean=True, feature_range=(0,1)):
        self.which = which
        self.with_mean = with_mean
        self.with_std = with_std
        self.feature_range = feature_range
        if which=="standard":
            # Calling the parent __init__ on this instance may overwrite the
            # attributes assigned above (presumably resetting with_mean/with_std
            # to StandardScaler's defaults -- confirm against the installed
            # scikit-learn). Assigning its return value (normally None) to
            # `self` only rebinds the local name; it does not replace the
            # object being constructed.
            self = StandardScaler.__init__(self)
        else:
            # Same local-name rebinding. This sets MinMaxScaler's attributes on
            # the instance, but fit/transform are still StandardScaler's (see
            # the MRO note above), so the data is still standardized.
            self = MinMaxScaler.__init__(self)

X = [[1,2,3],[3,4,5],[6,7,8]]
ct = CustomTransformer(which="standard")
ct.fit_transform(X)
# Output:
array([[-1.13554995, -1.13554995, -1.13554995],
       [-0.16222142, -0.16222142, -0.16222142],
       [ 1.29777137, 1.29777137, 1.29777137]])
ct = CustomTransformer(which="")
ct.fit_transform(X)
# Output -- identical to the "standard" case, i.e. no switch to min-max scaling:
array([[-1.13554995, -1.13554995, -1.13554995],
       [-0.16222142, -0.16222142, -0.16222142],
       [ 1.29777137, 1.29777137, 1.29777137]])
所以我的问题更偏向理论层面:
在 scikit-learn 中,应当如何正确地实现“按条件选择父类”的多重继承,从而在不同缩放器之间切换?
回答:
这样写就能正常工作:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.utils.validation import check_is_fitted

X = [[1, 2, 3], [3, 4, 5], [6, 7, 8]]


class CustomTransformer(TransformerMixin, BaseEstimator):
    """Switch between ``StandardScaler`` and ``MinMaxScaler`` via a parameter.

    Uses composition rather than inheritance. The selected scaler is built
    and fitted inside :meth:`fit`, so ``condition`` (and every other
    constructor argument) can be changed with ``set_params``, e.g. by
    ``GridSearchCV``, and the estimator can be cloned.

    Parameters
    ----------
    condition : bool
        Truthy selects ``StandardScaler``; falsy selects ``MinMaxScaler``.
    with_mean, with_std : bool, default=True
        Forwarded to ``StandardScaler``; unused when ``condition`` is falsy.
    feature_range : tuple, default=(0, 1)
        Forwarded to ``MinMaxScaler``; unused when ``condition`` is truthy.
    copy : bool, default=True
        Forwarded to whichever scaler is selected.
    clip : bool or None, default=None
        Forwarded to ``MinMaxScaler`` only when not ``None``; otherwise that
        scaler's own default applies. Unused when ``condition`` is truthy.

    Attributes
    ----------
    scaler_ : StandardScaler or MinMaxScaler
        The scaler fitted by the most recent call to :meth:`fit`.
    """

    def __init__(self, condition, with_mean=True, with_std=True,
                 feature_range=(0, 1), copy=True, clip=None):
        # scikit-learn convention: __init__ only stores its arguments, unchanged,
        # under the same names. BaseEstimator derives get_params/set_params/clone
        # from this signature. A scaler built here would not be rebuilt by
        # set_params(), and **kwargs would not appear in get_params().
        self.condition = condition
        self.with_mean = with_mean
        self.with_std = with_std
        self.feature_range = feature_range
        self.copy = copy
        self.clip = clip

    def _make_scaler(self):
        """Return a new, unfitted scaler selected by ``self.condition``."""
        if self.condition:
            return StandardScaler(with_mean=self.with_mean,
                                  with_std=self.with_std,
                                  copy=self.copy)
        extra = {} if self.clip is None else {"clip": self.clip}
        return MinMaxScaler(feature_range=self.feature_range,
                            copy=self.copy, **extra)

    def fit(self, X, y=None):
        """Fit the selected scaler on ``X``.

        ``y`` is ignored; it is accepted so Pipeline/GridSearchCV can pass it.
        Returns ``self`` (not the inner scaler), per scikit-learn convention.
        """
        # Build a fresh scaler on every fit so parameter changes take effect.
        self.scaler_ = self._make_scaler().fit(X)
        return self

    def transform(self, X):
        """Scale ``X`` with the scaler fitted in :meth:`fit`.

        Raises ``NotFittedError`` if called before :meth:`fit`.
        """
        check_is_fitted(self)
        return self.scaler_.transform(X)

    @property
    def scaler(self):
        """Backward-compatible access to the scaler this instance uses.

        After :meth:`fit` this is the fitted ``scaler_``; before that, a new
        unfitted scaler configured from the current parameters.
        """
        try:
            return self.scaler_
        except AttributeError:
            return self._make_scaler()
# condition=False selects MinMaxScaler; the (0, .1) feature_range is why the
# values below run from 0 to 0.1.
ct = CustomTransformer(False, feature_range=(0,.1))
ct.fit_transform(X)
# Output:
array([[0. , 0. , 0. ],
       [0.04, 0.04, 0.04],
       [0.1 , 0.1 , 0.1 ]])
# condition=True selects StandardScaler, so feature_range is not passed on;
# the values below equal the standardized output from the question.
ct = CustomTransformer(True, feature_range=(0,.1))
ct.fit_transform(X)
# Output:
array([[-1.13554995, -1.13554995, -1.13554995],
       [-0.16222142, -0.16222142, -0.16222142],
       [ 1.29777137, 1.29777137, 1.29777137]])
现在,GridSearchCV 可以通过 .get_params() 访问这个 CustomTransformer 的参数:
from sklearn.model_selection import GridSearchCV

# Wrap the transformer (ct is the condition=True instance created above) with
# an empty param_grid, only to inspect the parameter names GridSearchCV exposes.
gs = GridSearchCV(ct, param_grid={})
gs.get_params()
# Output: the wrapped transformer's parameters appear with an "estimator__"
# prefix (e.g. estimator__condition).
{'cv': None,
 'error_score': nan,
 'estimator__copy': True,
 'estimator__with_mean': True,
 'estimator__with_std': True,
 'estimator__condition': True,
 'estimator': <__main__.CustomTransformer at 0x7fbd8d3aa9d0>,
 'iid': 'deprecated',
 'n_jobs': None,
 'param_grid': {},
 'pre_dispatch': '2*n_jobs',
 'refit': True,
 'return_train_score': False,
 'scoring': None,
 'verbose': 0}