vowpal-wabbit: 使用多次通过、保留集和保留周期来避免过拟合?

我想使用出色的 vowpal wabbit 工具 通过以下命令训练二元分类的神经网络进行分类任务:

vw --binary --nn 4 train.vw -f category.model

并测试它:

vw --binary -t -i category.model -p test.vw

但我的结果非常差(与我的线性SVM估计器相比)。

我发现一个评论提到应该使用 训练通过次数 参数(--passes arg)。

所以我的问题是如何确定训练通过次数以避免模型被重新训练

P.S. 我应该使用 holdout_period 参数吗?如何使用?


回答:

问题中的测试命令是错误的。它没有输入(-p ... 表示输出预测)。另外,不清楚你是想测试还是预测,因为它说的是 test,但命令中使用了 -p ...

测试 意味着你有带标签的数据,并且你在评估模型的质量。严格来说:预测 意味着你没有标签,因此你实际上无法知道你的预测有多好。实际上,你也可以在保留的、带标签的数据上进行预测(通过忽略标签假装没有标签),然后评估这些预测的质量,因为你实际上有标签。

一般来说:

  • 如果你想进行二元分类,你应该使用 {-1, 1} 的标签,并使用 --loss_function logistic--binary 是一个独立的选项,意味着你希望预测是二元的(这会给你提供较少的信息)。

  • 如果你已经有了一个单独的带标签的测试集,你不需要保留集。

vw 中的保留机制旨在替代测试集并避免过拟合,仅在使用多次通过时才相关,因为在单次通过中,所有 示例实际上都被保留了;每个下一个(尚未见过的)示例被视为 1)无标签的用于预测,2)带标签的用于测试和模型更新。换句话说:你的训练集实际上也是你的测试集。

所以你可以选择在训练集上进行多次通过而不使用保留集

 vw --loss_function logistic --nn 4 -c --passes 2 --holdout_off train.vw -f model

然后使用一个单独的、带标签的测试集来测试模型:

 vw -t -i model test.vw

或者在同一个训练集上进行多次通过,并使用一些保留集作为测试集

vw --loss_function logistic --nn 4 -c --passes 20 --holdout_period 7 train.vw -f model

如果你没有测试集,并且你想通过使用多次通过来增强拟合,你可以要求 vw 保留每 N 个示例中的一个(默认的 N 是 10,但你可以使用 --holdout_period <N> 明确地覆盖它,如上所示)。在这种情况下,你可以指定更高的通过次数,因为 vw 会在保留集上的损失开始增长时自动进行早期终止。

你会注意到你达到了早期终止,因为 vw 会打印类似这样的内容:

passes used = 5...average loss = 0.06074 h

表示在早期停止之前只使用了 N 次通过中的 5 次,保留的示例子集上的错误为 0.06074(尾随的 h 表示这是保留集的损失)。

如你所见,通过次数保留周期 是完全独立的选项。

为了改进并增加你对模型的信心,你可以使用其他优化方法,改变保留周期,尝试其他 --nn 参数。你可能还想检查 vw-hypersearch 工具(在 utl 子目录中)来帮助找到更好的超参数。

以下是使用 vw-hypersearch 在源码中包含的一个测试集上的示例:

$ vw-hypersearch 1 20 vw --loss_function logistic --nn % -c --passes 20 --holdout_period 11 test/train-sets/rcv1_small.dat --binarytrying 13 ............. 0.133333 (best)trying 8 ............. 0.122222 (best)trying 5 ............. 0.088889 (best)trying 3 ............. 0.111111trying 6 ............. 0.1trying 4 ............. 0.088889 (best)loss(4) == loss(5): 0.0888895       0.08888

表明 45 应该是 --nn 的好参数,在每 11 个示例中保留 1 个的子集上产生 0.08888 的损失。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注