在GridSearchCV中,如何只传递param_grid中的默认参数?

我是一个初学者,以下是我目前的代码。

from sklearn.naive_bayes import GaussianNBfrom sklearn.decomposition import PCApca = PCA()model = GaussianNB()steps = [('pca', pca), ('model', model)]pipeline = Pipeline(steps)cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)modelwithpca = GridSearchCV(pipeline, param_grid= ,cv=cv)modelwithpca.fit(X_train,y_train)

这是一个本地测试,我希望达到的目标是,

i. 对数据集进行主成分分析(PCA)

ii. 使用高斯朴素贝叶斯分类器,并仅使用其默认参数

iii. 使用分层随机分割(StratifiedShuffleSplit)

最终,我希望将上述步骤转移到另一个函数中,该函数将保存分类器、数据集和特征列表以测试性能。

dump_classifier_and_data(modelwithpca, dataset, features)  

在param_grid部分,我不想测试任何参数列表。我只想使用高斯朴素贝叶斯的默认参数,如果这样做有意义的话。我应该如何修改代码?

另外,是否应该对分类器对象的实例化方式进行任何更改?


回答:

GridSearchCV的目的是测试管道中至少一个组件的不同参数(如果你不想测试不同参数,你不需要使用GridSearchCV)。因此,一般来说,如果你想测试不同的PCAn_components值,使用管道和GridSearchCV的格式应如下所示:

gscv = GridSearchCV(pipeline, param_grid={'{step_name}__{parameter_name}': [possible values]}, cv=cv)

例如:

# 这将对pca的3个不同n_components值进行交叉验证gscv = GridSearchCV(pipeline, param_grid={'pca__n_components': [3, 6, 10]}, cv=cv)

如果你使用GridSearchCV来调整PCA如上所示,这当然意味着你的模型将使用默认值。

如果你不需要参数调整,那么GridSearchCV就不是合适的选择,因为像这样使用模型的默认参数来进行GridSearchCV,只会产生一个参数组合的网格,这就相当于仅执行交叉验证。这样做没有意义 – 如果我正确理解了你的问题:

from sklearn.naive_bayes import GaussianNBfrom sklearn.decomposition import PCAfrom sklearn.pipeline import Pipelinepca = PCA()model = GaussianNB()steps = [('pca', pca), ('model', model)]pipeline = Pipeline(steps)cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=42)# 获取你的模型的默认参数,并将它们用作param_gridmodelwithpca = GridSearchCV(pipeline, param_grid={'model__' + k: [v] for k, v in model.get_params().items()}, cv=cv)# 将根据你的cv配置运行5次modelwithpca.fit(X_train,y_train)

希望这对你有帮助,祝你好运!

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注