我在查看omniglot maml示例时发现,他们在测试代码的开头使用了net.train()
。这似乎是个错误,因为这意味着在元测试时每个任务的统计数据会被共享:
def test(db, net, device, epoch, log): # 重要的是,在我们的测试过程中,我们*不*对模型进行微调,以简化操作。 # 大多数使用MAML进行此任务的研究论文在此处会进行额外的微调阶段, # 如果您将此代码用于研究,应添加该阶段。 net.train() n_test_iter = db.x_test.shape[0] // db.batchsz qry_losses = [] qry_accs = [] for batch_idx in range(n_test_iter): x_spt, y_spt, x_qry, y_qry = db.next('test') task_num, setsz, c_, h, w = x_spt.size() querysz = x_qry.size(1) # TODO: 或许可以将此部分抽取为一个独立的模块, # 这样就不必在`train`和`test`之间重复。 n_inner_iter = 5 inner_opt = torch.optim.SGD(net.parameters(), lr=1e-1) for i in range(task_num): with higher.innerloop_ctx(net, inner_opt, track_higher_grads=False) as (fnet, diffopt): # 通过对模型参数进行梯度步长来优化支持集的可能性。 # 这将模型的元参数适应于任务。 for _ in range(n_inner_iter): spt_logits = fnet(x_spt[i]) spt_loss = F.cross_entropy(spt_logits, y_spt[i]) diffopt.step(spt_loss) # 这些参数引发的查询损失和准确率。 qry_logits = fnet(x_qry[i]).detach() qry_loss = F.cross_entropy( qry_logits, y_qry[i], reduction='none') qry_losses.append(qry_loss.detach()) qry_accs.append( (qry_logits.argmax(dim=1) == y_qry[i]).detach()) qry_losses = torch.cat(qry_losses).mean().item() qry_accs = 100. * torch.cat(qry_accs).float().mean().item() print( f'[Epoch {epoch+1:.2f}] Test Loss: {qry_losses:.2f} | Acc: {qry_accs:.2f}' ) log.append({ 'epoch': epoch + 1, 'loss': qry_losses, 'acc': qry_accs, 'mode': 'test', 'time': time.time(), })
然而,每当我使用eval时,我的MAML模型就会发散(尽管我的测试是在mini-imagenet上进行的):
>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5941, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5942, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5940, grad_fn=<NormBackward1>)>maml_old (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>>maml_old (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5939, grad_fn=<NormBackward1>)eval_loss=0.9859228551387786, eval_acc=0.5907692521810531args.meta_learner.lr_inner=0.01==== in forward2>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(171440.6875, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(208426.0156, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(17067344., grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(40371.8125, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1.0911e+11, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(21.3515, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(5.4257e+13, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(128.9109, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(3994.7734, grad_fn=<NormBackward1>)>maml_new (before inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(9.5937, grad_fn=<NormBackward1>)>maml_new (after inner adapt): fmodel.model.features.conv1.weight.norm(2)=tensor(1682896., grad_fn=<NormBackward1>)eval_loss_sanity=nan, eval_acc_santiy=0.20000000298023224
那么,应该采取什么措施来避免这种发散现象呢?
注意事项:
- 重新训练非常昂贵。对我来说,使用MAML训练一个5层CNN需要18天。分布式解决方案在这里会非常有帮助 https://github.com/learnables/learn2learn/issues/170
- 或许在训练过程中仅使用train(即使在训练过程中进行评估可能是个好主意,以便将批次统计数据保存到检查点中)
- 或者下次从一开始就使用批次统计数据进行训练
相关链接:
- https://github.com/facebookresearch/higher/issues/107
- https://discuss.pytorch.org/t/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-higher-library/136022
- 如何在PyTorch中使批量归一化层不忘记刚刚使用的批量统计数据?
- https://discuss.pytorch.org/t/how-does-pytorch-s-batch-norm-know-if-the-forward-pass-its-doing-is-for-inference-or-training/16857/10
- https://stats.stackexchange.com/questions/544048/what-does-the-batch-norm-layer-for-maml-model-agnostic-meta-learning-do-for-du/551153#551153
- https://github.com/tristandeleu/pytorch-maml/issues/19
回答:
简而言之:使用mdl.train()
,因为它使用批次统计数据(但推理将不再是确定性的)。 在元学习中,您可能不希望使用mdl.eval()
。
批量归一化的预期行为:
- 重要的是,在推理(eval/测试)期间使用的是running_mean和running_std,这些是从训练中计算得出的(因为他们想要一个确定的输出,并使用总体统计数据的估计)。
- 在训练期间使用批次统计数据,但通过运行平均值来估计总体统计数据。我认为在训练期间使用批次统计数据是为了引入噪声,从而正则化训练(噪声鲁棒性)。
- 在元学习中,我认为在测试期间使用批次统计数据(而不计算运行均值)是最好的,因为我们应该看到新的任务/分布。付出的代价是失去确定性。出于好奇,使用从元训练中估计的总体统计数据的准确性会很有趣。
这可能就是为什么我在使用mdl.train()
进行测试时没有看到发散现象的原因。
因此,请确保使用mdl.train()
(因为它使用批次统计数据 https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d),但不要保存或稍后使用那些作弊的新运行统计数据。