为什么 sklearn 的 cross_validate() 需要重新拟合?

我理解像 GridSearchCV 这样的工具为什么需要重新拟合。它会在不同超参数值范围内进行探索,并在比较得分后,使用最佳找到的参数在整个数据集上重新拟合一个估计器。

但是,尽管这有道理,我的疑问是关于 cross_validate 类的,它只使用一组超参数。我理解其目的是为了观察模型在不同训练/测试分割折叠上的泛化能力。为什么这里需要重新拟合呢?

我理解为什么在 n 个数据折叠上会进行 n 次拟合。但根据文档,还会进行一次重新拟合,如 error_score 参数中所讨论的:

error_score : ‘raise’ 或数值

如果在估计器拟合过程中发生错误,则分配给得分的值。如果设置为‘raise’,则会引发错误。如果给定数值,则会引发 FitFailedWarning。此参数不影响重新拟合步骤,重新拟合步骤总是会引发错误。

因此,除了 n 次拟合之外,还会有一次额外的拟合,我不明白这是为什么。这个类没有预测方法,所以即使它以某种方式区分了模型并选择了一个“最佳”模型(尽管它们都有完全相同的超参数),重新拟合也没有意义。

为了证明这一点,我创建了一个 MLPRegressor 模型,我知道结合我的数据集会导致梯度爆炸:

DL = MLPRegressor(        hidden_layer_sizes=(200, 200, 200), activation='relu', max_iter=16,            solver='sgd', learning_rate='invscaling', power_t=0.9)DL.fit(df_training[predictor_cols], df_training[target_col])

模型能够无错误地拟合(证明我的数据集中没有 NaN 或 inf 值),但确实给出了警告:

RuntimeWarning: overflow encountered in matmul

这证明了梯度爆炸,因此任何预测的输出都是 NaN。

根据我对 cross_validate 文档的理解,如果我传递以下内容(使用 error_score=1):

DL = MLPRegressor(        hidden_layer_sizes=(200, 200, 200), activation='relu', max_iter=16,            solver='sgd', learning_rate='invscaling', power_t=0.9)DL_CV = cross_validate(DL, df_training[predictor_cols], y=df_training[target_col], cv=None, n_jobs=1, pre_dispatch=5, return_train_score=False, return_estimator=True, error_score=1)

我应该会收到 ‘FitFailedWarning’ 消息,但不会有错误。然而,训练并未完成,反而引发了以下错误:

ValueError: Input contains NaN, infinity or a value too large for dtype(‘float64’).

因此,我得出结论,这个错误是由于重新拟合造成的,但我不知道重新拟合的目的是什么…


回答:

cross_validate 实际上并不重新拟合,你可以从源代码中验证这一点。文档是错误的,可能是从 GridSearchCV 的文档中复制过来的。你应该开启一个 Issue 或提交一个拉取请求;如果你不愿意,我可以帮你做。

不过,我不知道你最后的错误来源;也许错误是在成功拟合的模型评分时引发的,而不是在拟合过程中?如果原始拟合只引发警告,那么在搜索过程中默认是不会捕获到这个警告的。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注