scikit learn: 自定义分类器与GridSearchCV兼容

我已经实现了自己的分类器,现在我想对其进行网格搜索,但出现了以下错误: estimator.fit(X_train, y_train, **fit_params)TypeError: fit() takes 2 positional arguments but 3 were given

我按照这个教程并使用了这个模板,这是由scikit的官方文档提供的。我的类定义如下:

class MyClassifier(BaseEstimator, ClassifierMixin):    def __init__(self, lr=0.1):        self.lr=lr    def fit(self, X, y):        # 一些代码        return self    def predict(self, X):        # 一些代码        return y_pred    def get_params(self, deep=True)        return {'lr':self.lr}    def set_params(self, **parameters):        for parameter, value in parameters.items():            setattr(self, parameter, value)        return self

我尝试通过以下方式进行网格搜索:

params = {    'lr': [0.1, 0.5, 0.7]}gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

编辑 I

这是我调用的方式: gs.fit([‘hello world’, ‘trying’,’hello world’, ‘trying’, ‘hello world’, ‘trying’, ‘hello world’, ‘trying’], [‘I’, ‘Z’, ‘I’, ‘Z’, ‘I’, ‘Z’, ‘I’, ‘Z’])

结束编辑 I

错误是由_fit_and_score方法在文件python3.5/site-packages/sklearn/model_selection/_validation.py中产生的

它调用了estimator.fit(X_train, y_train, **fit_params),传入了3个参数,但我的分类器只有两个参数,所以这个错误是有道理的,但我不知道如何解决… 我也尝试在fit方法中添加一些虚拟参数,但没有效果。

编辑 II

完整的错误输出:

Traceback (most recent call last):  File "/home/rodrigo/no_version/text_classifier/MyClassifier.py", line 355, in <module>    ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 639, in fit    cv.split(X, y, groups)))  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__    while self.dispatch_one_batch(iterator):  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch    self._dispatch(tasks)  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch    job = self._backend.apply_async(batch, callback=cb)  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async    result = ImmediateResult(func)  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__    self.results = batch()  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__    return [func(*args, **kwargs) for func, args, kwargs in self.items]  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>    return [func(*args, **kwargs) for func, args, kwargs in self.items]  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_validation.py", line 458, in _fit_and_score    estimator.fit(X_train, y_train, **fit_params)TypeError: fit() takes 2 positional arguments but 3 were given

结束编辑 II

已解决感谢大家,我犯了一个愚蠢的错误:有两个同名的函数(fit),(我为自定义目的实现了另一个带有不同参数的函数,只要我重命名了我的’自定义fit’,它就正常工作了。)

谢谢你们,对不起


回答:

以下代码对我来说是有效的:

class MyClassifier(BaseEstimator, ClassifierMixin):     def __init__(self, lr=0.1):         self.lr = lr         # 一些代码         pass     def fit(self, X, y):         # 一些代码         pass     def predict(self, X):         # 一些代码         return X % 3params = {    'lr': [0.1, 0.5, 0.7]}gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)x = np.arange(30)y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))gs.fit(x, y)

我能想到的最好的解释是,你在gs.fit方法中传递了除xy之外的其他东西,或者你的MyClassifier.fit方法缺少了self参数。

fit_params kwargs 只有在你向gs.fit方法传递关键字参数时才会被填充,否则它是一个空字典({}),并且**fit_params不会抛出参数错误。为了测试这一点,创建你的分类器的一个实例并传递**{}。例如:

clf = MyClassifier()clf.fit(x, y, **{})

这不会抛出位置参数错误。

因此,除非向gs.fit传递了某些东西,例如gs.fit(x, y, some_arg=123),否则我认为你在MyClassifier.fit的定义中缺少了一个位置参数。你包含的错误消息似乎支持这一假设,因为它指出fit() takes 2 positional arguments but 3 were given。如果你定义fit如下,它将接受3个位置参数:

def fit(self, X, y): ...

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注