在使用PyTorch higher库进行MAML时,什么时候应该调用.eval()和.train()?

我在查看omniglot maml示例时发现,他们在测试代码的开头使用了net.train()。这似乎是个错误,因为这意味着在元测试时每个任务的统计数据会被共享:

def test(db, net, device, epoch, log):    # 重要的是,在我们的测试过程中,我们*不*对模型进行微调,以简化操作。    # 大多数使用MAML进行此任务的研究论文在此处会进行额外的微调阶段,    # 如果您将此代码用于研究,应添加该阶段。    net.train()    n_test_iter = db.x_test.shape[0] // db.batchsz    qry_losses = []    qry_accs = []    for batch_idx in range(n_test_iter):        x_spt, y_spt, x_qry, y_qry = db.next('test')        task_num, setsz, c_, h, w = x_spt.size()        querysz = x_qry.size(1)        # TODO: 或许可以将此部分抽取为一个独立的模块,        # 这样就不必在`train`和`test`之间重复。        n_inner_iter = 5        inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1)        for i in range(task_num):            with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt):                # 通过对模型参数进行梯度步长来优化支持集的可能性。                # 这将模型的元参数适应于任务。                for _ in range(n_inner_iter):                    spt_logits = fnet(x_spt[i])                    spt_loss = F.cross_entropy(spt_logits, y_spt[i])                    diffopt.step(spt_loss)                # 这些参数引发的查询损失和准确率。                qry_logits = fnet(x_qry[i]).detach()                qry_loss = F.cross_entropy(                    qry_logits, y_qry[i], reduction='none')                qry_losses.append(qry_loss.detach())                qry_accs.append(                    (qry_logits.argmax(dim=1) == y_qry[i]).detach())    qry_losses = torch.cat(qry_losses).mean().item()    qry_accs = 100. * torch.cat(qry_accs).float().mean().item()    print(        f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}'    )    log.append({        'epoch': epoch + 1,        'loss': qry_losses,        'acc': qry_accs,        'mode': 'test',        'time': time.time(),    })

然而,每当我使用eval时,我的MAML模型就会发散(尽管我的测试是在mini-imagenet上进行的):

>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5942, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)eval_loss=0.9859228551387786, eval_acc=0.5907692521810531args.meta_learner.lr_inner=0.01==== in forward2>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(171440.6875, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(208426.0156, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(17067344., grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(40371.8125, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.0911e+11, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(21.3515, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(5.4257e+13, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(128.9109, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(3994.7734, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1682896., grad_fn=<NormBackward1>)eval_loss_sanity=nan, eval_acc_santiy=0.20000000298023224

那么,应该采取什么措施来避免这种发散现象呢?

注意事项:

  • 重新训练非常昂贵。对我来说,使用MAML训练一个5层CNN需要18天。分布式解决方案在这里会非常有帮助 https://github.com/learnables/learn2learn/issues/170
  • 或许在训练过程中仅使用train(即使在训练过程中进行评估可能是个好主意,以便将批次统计数据保存到检查点中)
  • 或者下次从一开始就使用批次统计数据进行训练

相关链接:


回答:

简而言之:使用mdl.train(),因为它使用批次统计数据(但推理将不再是确定性的)。 在元学习中,您可能不希望使用mdl.eval()


批量归一化的预期行为:

  • 重要的是,在推理(eval/测试)期间使用的是running_mean和running_std,这些是从训练中计算得出的(因为他们想要一个确定的输出,并使用总体统计数据的估计)。
  • 在训练期间使用批次统计数据,但通过运行平均值来估计总体统计数据。我认为在训练期间使用批次统计数据是为了引入噪声,从而正则化训练(噪声鲁棒性)。
  • 在元学习中,我认为在测试期间使用批次统计数据(而不计算运行均值)是最好的,因为我们应该看到新的任务/分布。付出的代价是失去确定性。出于好奇,使用从元训练中估计的总体统计数据的准确性会很有趣。

这可能就是为什么我在使用mdl.train()进行测试时没有看到发散现象的原因。

因此,请确保使用mdl.train()(因为它使用批次统计数据 https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d),但不要保存或稍后使用那些作弊的新运行统计数据。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注