"Object is not iterable" when implementing pipeline transformers in scikit-learn

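I have a set of strings that I need to classify, and I am using a Pipeline object to do it.

I implemented two dummy transformers: one converts the data into a specific format (so that another transformer will accept it), and the other converts the data back to its original form (an inverse transformation).

X and y are lists of strings, for example X = ['London is great', 'London is beautiful', 'I hate London'] and y = ['p', 'p', 'n']. I want X to be converted into a list of lists of strings: X = [['London is great'], ['London is beautiful'], ['I hate London']]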
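To make the intended round trip concrete, here is a minimal sketch in plain Python, outside of any pipeline. The names wrapped and unwrapped are only for illustration and do not appear in my real code.

```python
X = ['London is great', 'London is beautiful', 'I hate London']
y = ['p', 'p', 'n']

# Forward step: wrap every string in its own single-element list.
wrapped = [[x] for x in X]

# Inverse step: take the single element back out of each inner list.
unwrapped = [row[0] for row in wrapped]
```

After both steps, unwrapped holds the same strings as the original X.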

My code is as follows:

    from sklearn.naive_bayes import MultinomialNB
    from sklearn.feature_extraction.text import CountVectorizer
    from sklearn.feature_selection import SelectKBest, chi2
    from sklearn.pipeline import Pipeline
    from sklearn.base import TransformerMixin, BaseEstimator

    vectorizer = CountVectorizer(input=u'content',
                                 analyzer=u'word',
                                 lowercase=True,
                                 stop_words=cached_stopwords,
                                 strip_accents=u'unicode',
                                 ngram_range=(1, 3), binary=False)

    estimators = [('pre_ds', PreprocessPreDS()),
                  ('post_ds', PreprocesarPostDS()),
                  ('vectorizer', vectorizer),
                  ('feature_selector', SelectKBest(chi2, k=100)),
                  ('clf', MultinomialNB())]

    # create the pipeline
    pipe = Pipeline(estimators)
    pipe.fit(X_train, y_train)

My custom transformers are as follows:

    class PreprocessPreDS(BaseEstimator, TransformerMixin):
        def __init__(self):
            pass

        def transform(self, X, *_):
            return [[x] for x in X]

        def fit(self, *_):
            return self

        def fit_transform(self, X, y=None, **fit_params):
            return self.fit(X)

        def get_params(self, deep=True):
            """
            :param deep: ignored, as suggested by scikit learn's documentation
            :return: dict containing each parameter from the model as name and its current value
            """
            return {}

        def set_params(self, **parameters):
            """
            set all parameters for current objects
            :param parameters: dict containing its keys and values to be initialised
            :return: self
            """
            for parameter, value in parameters.items():
                setattr(self, parameter, value)
            return self


    class PreprocesarPostDS(BaseEstimator, TransformerMixin):
        def __init__(self):
            pass

        def transform(self, X, *_):
            return [x[0] for x in X]

        def fit(self, *_):
            return self

        def fit_transform(self, X, y=None, **fit_params):
            return self.fit(X)

        def get_params(self, deep=True):
            """
            :param deep: ignored, as suggested by scikit learn's documentation
            :return: dict containing each parameter from the model as name and its current value
            """
            return {}

        def set_params(self, **parameters):
            """
            set all parameters for current objects
            :param parameters: dict containing its keys and values to be initialised
            :return: self
            """
            for parameter, value in parameters.items():
                setattr(self, parameter, value)
            return self

When I run this code, I get the following error:

    Traceback (most recent call last):
      File "/home/rodrigo/nb/train_nb_pipeline.py", line 449, in <module>
        process(args.label, args.evaluate, args.label_all, corpus=args.corpus_path)
      File "/home/rodrigo/nb/train_nb_pipeline.py", line 179, in process
        pipe.fit(X_train, y_train)
      File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/pipeline.py", line 248, in fit
        Xt, fit_params = self._fit(X, y, **fit_params)
      File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/pipeline.py", line 213, in _fit
        **fit_params_steps[name])
      File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/externals/joblib/memory.py", line 362, in __call__
        return self.func(*args, **kwargs)
      File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/pipeline.py", line 581, in _fit_transform_one
        res = transformer.fit_transform(X, y, **fit_params)
      File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/feature_extraction/text.py", line 869, in fit_transform
        self.fixed_vocabulary_)
      File "/home/rodrigo/.env/lib/python3.5/site-packages/sklearn/feature_extraction/text.py", line 790, in _count_vocab
        for doc in raw_documents:
    TypeError: 'PreprocessPostDS' object is not iterable

However, if I remove ('pre_ds', PreprocessPreDS()) and ('post_ds', PreprocesarPostDS()) from estimators, everything runs fine.


Answer:

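Change this:

    def fit_transform(self, X, y=None, **fit_params):
        return self.fit(X)

to this:

    def fit_transform(self, X, y=None, **fit_params):
        return self.fit(X).transform(X)

In the original code, fit_transform() returns self.fit(X), and fit() returns self, so what you actually hand back is the transformer object itself (in this case the PreprocessPreDS and PreprocessPostDS instances). fit_transform() should return the transformed data, not the object. The pipeline passes whatever fit_transform() returns on to the next step, so the CountVectorizer ends up trying to iterate over the transformer rather than over a list of strings, and that is what raises the TypeError.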
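The difference can be reproduced without scikit-learn at all. The sketch below is a toy illustration, not the library's actual code: BuggyWrapper and FixedWrapper are hypothetical stand-ins for the question's transformer, and count_documents is a hypothetical stand-in for a later pipeline step that loops over its input the way the vectorizer does in the traceback.

```python
class BuggyWrapper:
    """Mirrors the question's PreprocessPreDS, including the bug."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return [[x] for x in X]

    def fit_transform(self, X, y=None):
        return self.fit(X)  # bug: returns the transformer, not data


class FixedWrapper(BuggyWrapper):
    """Same transformer with the corrected fit_transform."""

    def fit_transform(self, X, y=None):
        return self.fit(X).transform(X)  # returns the transformed data


def count_documents(raw_documents):
    """Hypothetical next step: it simply iterates over what it receives."""
    return sum(1 for doc in raw_documents)


X = ['London is great', 'London is beautiful', 'I hate London']

# The fixed version hands a list to the next step, so iteration works.
fixed_output = FixedWrapper().fit_transform(X)
n_docs = count_documents(fixed_output)

# The buggy version hands over the transformer object, so iteration fails.
try:
    count_documents(BuggyWrapper().fit_transform(X))
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

Here error_message ends up holding a "not iterable" TypeError message of the same kind as the one in the traceback. Note also that TransformerMixin already supplies a fit_transform() that calls fit() and then transform(), so deleting the override from both classes should have the same effect as correcting it.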
