为什么在使用transform方法得到相同输出时,我们还需要使用fit_transform方法?

我不明白为什么在transform方法可以提供与仅使用fit_transform方法相同的输出时,还必须使用fit_transform方法,fit方法的全部意义是什么?

我已经打印了x_trainx_test,它们都给出了相似的输出。

from sklearn.preprocessing import StandardScalersc = StandardScaler()x_train[:, 3:] = sc.fit_transform(x_train[:, 3:])x_test[:, 3:] = sc.transform(x_test[:, 3:])

回答:

在scikit-learn预处理器中,你通常总是会有fittransformfit_transform方法。

它们的区别如下:

fit方法会“学习”你的数据结构,以找出其中存在的类别和其他预处理信息。一旦你已经拟合了预处理器,你就可以使用这个已经拟合的预处理器来transform你的数据,使用那个“拟合”信息。让我们看一个简单的例子:

import numpy as np from sklearn.preprocessing import StandardScalerX_train = np.array([[1, 2], [3, 4], [5, 6]])X_test = np.array([[7, 8], [9, 10]])X_train:array([[1, 2],       [3, 4],       [5, 6]])X_test:array([[ 7,  8],       [ 9, 10]])

这里你正在准备一个标准化缩放器对象

sc = StandardScaler()

这个对象必须有一些参数来保存信息,比如数据的均值等。但由于它尚未看到任何数据,这个均值还不存在,所以下面的代码将显示一个错误

print(sc.mean_)AttributeError: 'StandardScaler' object has no attribute 'mean_'

现在让我们用它来拟合X_train数据

sc.fit(X_train)

让我们看看这个操作之后发生了什么

print(sc.mean_)[3. 4.]

现在我们可以看到我们的标准化缩放器对象已经计算了它看到的数据的均值,并将其存储在其属性之一,即mean_

所以这基本上就是fit方法的作用:它是用来查找关于一些数据的参数,在我们的案例中是训练数据。我们为什么要先找到这些参数,是因为我们可能希望准确地重用它们来转换其他数据。这就是transform方法的用武之地。

transform方法使用之前一些数据的“学习”参数来转换一些新数据。因此,在我们的案例中,我们现在可以转换我们的测试数据。这是因为训练和测试数据应该以相同的方式(使用相同的参数,如均值等)进行转换

sc.transform(X_test)array([[2.44949 , 2.44949 ],       [3.674235, 3.674235]])

但当然,我们也应该首先转换训练数据本身!

sc.transform(X_train)array([[-1.224745, -1.224745],       [ 0.      ,  0.      ],       [ 1.224745,  1.224745]])

如你所见,我们已经连续拟合然后转换了我们的训练数据,而我们只转换了测试数据,而不需要拟合它。连续拟合和转换就是fit_transform方法的用武之地。因此,对于训练数据,我们可以直接这样做:

X_train = sc.fit_transform(X_train)array([[-1.224745, -1.224745],       [ 0.      ,  0.      ],       [ 1.224745,  1.224745]])

这个方法先拟合数据,然后转换它。但你不能在没有拟合数据的情况下就转换数据。现在你已经使用fit_transform或只是fit拟合了你的训练数据,现在你可以只用与训练数据相同的拟合信息来转换你的测试数据。

希望这些解释足够清楚。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注