我想了解 fit() 函数的一般功能,尤其是在下面的代码片段中它具体做了什么。
我正在学习机器学习A-Z课程,因为我对机器学习还比较陌生(我刚开始学)。我知道一些基本概念术语,但不太了解技术部分。
代码1:
from sklearn.impute import SimpleImputermissingvalues = SimpleImputer(missing_values = np.nan, strategy = 'mean', verbose = 0) missingvalues = missingvalues.fit(X[:, 1:3])X[:, 1:3] = missingvalues.transform(X[:, 1:3])
另一个让我仍然有疑问的例子
代码2:
from sklearn.preprocessing import StandardScalersc_X = StandardScaler()print(sc_X)X_train = sc_X.fit_transform(X_train)print(X_train)X_test = sc_X.transform(X_test)
我认为如果我能了解这个函数的一般用途以及它通常做什么,我就可以继续了。但我当然也想知道它在这些代码中具体做了什么
回答:
这里还有一个不错的检查机会: https://scikit-learn.org/stable/tutorial/basic/tutorial.html
fit
方法在机器学习中总是用来学习某些东西的。
你通常会有以下步骤:
- 将你的数据分成两个或三个数据集
- 选取数据的一部分来学习/训练某些东西(通常是
X_train
)并使用fit
- 使用学到的算法对未见过的数据(通常是
X_test
)进行预测,使用predict
在你的第一个例子中:missingvalues.fit(X[:, 1:3])
你是在根据数据 X
训练 SimpleImputer
,其中你只使用了第 1,2,3
列,然后用 transform
将这个训练结果应用到这些数据上,进行覆盖。
在你的第二个例子中:你是在用 X_train
训练 StandardScaler
,并将这个训练结果应用到两个数据集 X_train
和 X_test
上,StandardScaler 从 X_train
中学习到,如果它学到10需要转换为2,那么它会在 X_train
和 X_test
两个集合中都将10转换为2。