### 在PyTorch中为随机权重平均设置学习率

以下是从这里获取的PyTorch中随机权重平均的一个小型工作代码示例。

loader, optimizer, model, loss_fn = ...swa_model = torch.optim.swa_utils.AveragedModel(model)scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=300)swa_start = 160swa_scheduler = SWALR(optimizer, swa_lr=0.05)for epoch in range(300):    for input, target in loader:        optimizer.zero_grad()        loss_fn(model(input), target).backward()        optimizer.step()        if epoch > swa_start:            swa_model.update_parameters(model)            swa_scheduler.step()        else:            scheduler.step()    # 在末尾更新swa_model的bn统计数据    torch.optim.swa_utils.update_bn(loader, swa_model)    # 使用swa_model对测试数据进行预测    preds = swa_model(test_input)

在这个代码中,在第160个epoch之后,使用了swa_scheduler而不是常规的scheduler。那么swa_lr代表什么呢?文档中提到,

通常,在SWA中,学习率被设置为一个较高的常数值。SWALR是一种学习率调度器,它将学习率调整到一个固定值,然后保持不变。

  1. 那么在第160个epoch之后,optimizer的学习率会发生什么变化?
  2. swa_lr是否会影响optimizer的学习率?

假设在代码的开始,optimizer是用学习率1e-4初始化的ADAM。那么上述代码是否意味着在前160个epoch中,训练的学习率将是1e-4,然后在剩余的epoch中将是swa_lr=0.05?如果是的话,将swa_lr也定义为1e-4是否是一个好主意?


回答:

  • 上述代码是否意味着在前160个epoch中,训练的学习率将是1e-4

    不会是1e-4,在前160个epoch中,学习率是由第一个调度器scheduler管理的。这个调度器被初始化为torch.optim.lr_scheduler.CosineAnnealingLR。学习率将遵循以下曲线:

    enter image description here


  • 在剩余的epoch中将是swa_lr=0.05

    这部分是对的,在第二部分 – 从第160个epoch开始 – 优化器的学习率将由第二个调度器swa_scheduler管理。这个调度器被初始化为torch.optim.swa_utils.SWALR。您可以在文档页面上阅读到:

    SWALR是一种学习率调度器,它将学习率调整到一个固定值[swa_lr],然后保持不变

    默认情况下(参见源代码),调整前的epoch数等于10。因此,从第170个epoch到第300个epoch,学习率将等于swa_lr,并将保持不变。第二部分将是:

    enter image description here

    这个完整的配置,即两部分:

    enter image description here


  • 如果是的话,将swa_lr也定义为1e-4是否是一个好主意

    文档中提到:

    通常,在SWA中,学习率被设置为一个较高的常数值。

    swa_lr设置为1e-4将导致以下学习率配置文件:

    enter image description here

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注