使用h2o.gbm()进行重复运行

我需要使用h2o.gbm函数进行重复运行,并且希望在使用相同超参数的情况下得到不同的结果。

尽管我已经创建了一个循环,为每个配置提供双重运行,并且使用h2o.performance函数提取了这些h2o gbm模型运行的结果;但我刚刚发现每次双重运行的结果完全相同。

你有什么建议让我在使用相同超参数运行两个h2o.gbm模型时得到不同的结果吗?

我尝试过以下方法:

  1. 尝试过使用不同的nthreads进行h2o.shutdown和h2o.init
  2. 更改和删除h2o.gbm中的seed参数
  3. 删除score_tree_interval和stopping_round参数

所有这些尝试都失败了,相同超参数的两次运行得到了完全相同的结果。此外,我分享了一个样本超参数配置,希望通过运行它两次得到不同的结果。

h2o.gbm(x = x_col_names, y = y,         training_frame = train_h2o,         fold_column = "index_4seasons",        ntrees = 1000,         max_depth = 5,         learn_rate = 0.1,         stopping_rounds = 5,         score_tree_interval = 10,         seed = 1)

任何帮助和评论都将不胜感激。


回答:

种子值会略微改变结果。请看下面的示例,演示了在使用文档中的示例时MSE会发生变化。

# 导入前列腺数据集到H2O:train_h2o = h2o.import_file("http://s3.amazonaws.com/h2o-public-test-data/smalldata/prostate/prostate.csv")# 设置预测变量和响应变量;设置因子:train_h2o["CAPSULE"] = train_h2o["CAPSULE"].asfactor()x_col_names = ["ID","AGE","RACE","DPROS","DCAPS","PSA","VOL","GLEASON"]y = "CAPSULE"# 构建并训练第一个模型:pros_gbm1 = H2OGradientBoostingEstimator(    nfolds = 5, ntrees = 1000, max_depth = 5, learn_rate = 0.1,     stopping_rounds = 5, score_tree_interval = 10, seed = 1)pros_gbm1.train(x = x_col_names, y = y,                 training_frame = train_h2o)# 构建并训练第二个模型,仅更改种子数:pros_gbm2 = H2OGradientBoostingEstimator(    nfolds = 5, ntrees = 1000, max_depth = 5, learn_rate = 0.1,     stopping_rounds = 5, score_tree_interval = 10, seed = 123456789)pros_gbm2.train(x = x_col_names, y = y,                 training_frame = train_h2o)print('模型1 MSE:', pros_gbm1.mse())print('模型2 MSE:', pros_gbm2.mse())

输出

模型1 MSE: 0.020725291770552916模型2 MSE: 0.02189654172905499

如果你的数据集在使用不同的种子和硬件设置时仍然产生可复现的结果,可能是因为数据集不够大或不够复杂,无法使模型表现出随机性。你还可以尝试更改fold_column中的折叠设置,看看是否有影响。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注