为什么 sklearn Pipeline 调用 transform() 的次数远多于 fit()?

经过大量阅读和检查在不同 verbose 参数设置下的 pipeline.fit() 操作后,我仍然对我的 pipeline 访问某个步骤的 transform 方法如此频繁感到困惑。

下面是一个简单的 pipeline 示例,使用 GridSearchCV 进行 fit,采用 3 折交叉验证,但参数网格只有一组超参数。因此,我预期 pipeline 会运行三次。预期 step1step2 都会被调用 fit 三次,但每个步骤的 transform 被调用的次数却多得多。这是为什么呢?下面是简化的代码示例和日志输出。

# 导入库import pandas as pdfrom sklearn import datasetsfrom sklearn.model_selection import KFoldfrom sklearn.linear_model import LogisticRegressionfrom sklearn.base import TransformerMixin, BaseEstimatorfrom sklearn.pipeline import Pipeline# 加载示例数据iris = datasets.load_iris()X = pd.DataFrame(iris.data, columns = iris.feature_names)y = pd.Series(iris.target, name='y')# 定义几个简单的 pipeline 步骤class mult_everything_by(TransformerMixin, BaseEstimator):    def __init__(self, multiplier=2):        self.multiplier = multiplier    def fit(self, X, y=None):        print "Fitting step 1"        return self    def transform(self, X, y=None):        print "Transforming step 1"        return X* self.multiplierclass do_nothing(TransformerMixin, BaseEstimator):    def __init__(self, meaningless_param = 'hello'):        self.meaningless_param=meaningless_param    def fit(self, X, y=None):        print "Fitting step 2"        return self    def transform(self, X, y=None):        print "Transforming step 2"        return X# 定义 Pipeline 中的步骤pipeline_steps = [('step1', mult_everything_by()),                  ('step2', do_nothing()),                   ('classifier', LogisticRegression()),                  ]pipeline = Pipeline(pipeline_steps)# 为了保持这个示例非常简洁,这个参数网格只有一个超参数组,# 因此我们只拟合一种模型类型param_grid = {'step1__multiplier': [2],   #,3],              'step2__meaningless_param': ['hello']   #, 'howdy', 'goodbye']              }# 定义模型搜索过程/对象# (拟合一个模型,由于 3 折交叉验证会进行 3 次拟合)cv_model_search = GridSearchCV(pipeline,                                param_grid,                                cv = KFold(3),                               refit=False,                                verbose = 0) # 拟合模型搜索对象中定义的所有(1)模型cv_model_search.fit(X,y)

输出:

Fitting step 1Transforming step 1Fitting step 2Transforming step 2Transforming step 1Transforming step 2Transforming step 1Transforming step 2Fitting step 1Transforming step 1Fitting step 2Transforming step 2Transforming step 1Transforming step 2Transforming step 1Transforming step 2Fitting step 1Transforming step 1Fitting step 2Transforming step 2Transforming step 1Transforming step 2Transforming step 1Transforming step 2

回答:

因为你使用了 GridSearchCVcv = KFold(3),这将对你的模型进行交叉验证。以下是发生的情况:

  1. 它会将数据分成两部分:训练集和测试集。
  2. 对于训练集,它会拟合并转换 pipeline 的每个部分(不包括最后一个,即分类器)。这就是为什么你会看到 fit step1, transform step1, fit step2, transform step2
  3. 它会用转换后的数据拟合分类器(这部分在你的输出中没有显示)。
  4. 已编辑 现在进入评分部分。我们不希望再次拟合这些部分。我们将使用之前拟合过程中学到的信息。因此,pipeline 的每个部分只会调用 transform()。这就是 Transforming step 1, Transforming step 2 出现的原因。

    它显示了两次是因为在 GridSearchCV 中,默认行为是计算训练数据和测试数据的得分。这个行为由 return_train_score 控制。你可以设置 return_train_score=False,这样就只会看到一次。

  5. 转换后的测试数据将用于从分类器中预测输出。(同样,测试数据上没有拟合,只有预测或转换)。

  6. 预测值将用于与实际值进行比较以评分模型。
  7. 步骤 1-6 将重复 3 次 (KFold(3))
  8. 现在看一下你的参数:

    param_grid = {‘step1__multiplier’: [2], #,3], ‘step2__meaningless_param’: [‘hello’] #, ‘howdy’, ‘goodbye’] }

    展开后,只有一个组合,即:

    组合1: ‘step1__multiplier’=2, ‘step2__meaningless_param’ = ‘hello’

    如果你提供了更多选项,你已经注释掉的更多组合将是可能的,比如:

    组合1: ‘step1__multiplier’=2, ‘step2__meaningless_param’ = ‘hello’

    组合2: ‘step1__multiplier’=3, ‘step2__meaningless_param’ = ‘hello’

    组合3: ‘step1__multiplier’=2, ‘step2__meaningless_param’ = ‘howdy’

    依此类推…

  9. 步骤 1-7 将针对每种可能的组合重复进行。

  10. 在交叉验证的测试折叠上平均得分最高的组合将被选中,最终用完整数据拟合模型(没有再分成训练集和测试集)。
  11. 但是你设置了 refit=False。因此模型不会再次拟合。否则你会看到更多的输出,如下所示:

    Fitting step 1Transforming step 1Fitting step 2Transforming step 2

希望这能澄清你的疑问。如有更多问题,请随时提问。

Related Posts

L1-L2正则化的不同系数

我想对网络的权重同时应用L1和L2正则化。然而,我找不…

使用scikit-learn的无监督方法将列表分类成不同组别,有没有办法?

我有一系列实例,每个实例都有一份列表,代表它所遵循的不…

f1_score metric in lightgbm

我想使用自定义指标f1_score来训练一个lgb模型…

通过相关系数矩阵进行特征选择

我在测试不同的算法时,如逻辑回归、高斯朴素贝叶斯、随机…

可以将机器学习库用于流式输入和输出吗?

已关闭。此问题需要更加聚焦。目前不接受回答。 想要改进…

在TensorFlow中,queue.dequeue_up_to()方法的用途是什么?

我对这个方法感到非常困惑,特别是当我发现这个令人费解的…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注