I understand that when using loss.backward() with multiple networks and multiple loss functions, where each network is optimized separately, we need to specify retain_graph=True. But whether or not I pass this argument, I still run into errors. Below is a minimal working example (MWE) to reproduce the issue (on PyTorch 1.6).
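A rough sketch of such a setup (not the original MWE; the identifiers gru1, gru2, gru1_opt and vector follow the answer below, while every size, target and loss here is an assumption) that should reproduce both errors on the CPU:

```python
import torch
from torch import nn, optim

gru1 = nn.GRU(input_size=2, hidden_size=100)   # weight_hh_l0 is (300, 100); its transpose has the [100, 300] shape seen in the error
gru2 = nn.GRU(input_size=100, hidden_size=2)   # consumes gru1's output
gru1_opt = optim.Adam(gru1.parameters())
gru2_opt = optim.Adam(gru2.parameters())
criterion = nn.MSELoss()

for _ in range(10):
    gru1_opt.zero_grad()
    gru2_opt.zero_grad()

    vector = torch.randn(15, 8, 2)             # (seq_len, batch, input_size)
    gru1_output, _ = gru1(vector)
    gru2_output, _ = gru2(gru1_output)         # gru2's graph extends gru1's graph

    loss_gru1 = criterion(gru1_output, torch.randn_like(gru1_output))
    loss_gru2 = criterion(gru2_output, torch.randn_like(gru2_output))

    loss_gru1.backward(retain_graph=True)      # or plain .backward()
    gru1_opt.step()                            # updates gru1's weights in place

    loss_gru2.backward()                       # fails with one of the two errors
    gru2_opt.step()
```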
When retain_graph is set to True, I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 300]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!
Without that argument, the error is:
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
That is the expected error.
Please point out what changes need to be made to the code above so that training can proceed. Any help is appreciated.
Answer:
In this case, you can detach the computation graph to exclude the parameters that do not need to be optimized. Here, the graph should be detached after the second forward pass with gru1, i.e.:
```python
....
gru1_opt.step()
gru1_output, _ = gru1(vector)
gru1_output = gru1_output.detach()
....
```
This way, you will not be "trying to backward through the graph a second time", as the error message says.
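For completeness, a corrected training step might look like the sketch below (same assumed setup as the sketch in the question; only the gru1_opt.step(), second forward pass, and detach() sequence comes from the answer, the rest is an assumption). Because gru1_output is detached, loss_gru2.backward() only reaches gru2's parameters, so neither backward call needs retain_graph=True:

```python
for _ in range(10):
    gru1_opt.zero_grad()
    gru2_opt.zero_grad()

    vector = torch.randn(15, 8, 2)

    # Train gru1 on its own loss; this graph is never reused,
    # so no retain_graph is needed.
    gru1_output, _ = gru1(vector)
    loss_gru1 = criterion(gru1_output, torch.randn_like(gru1_output))
    loss_gru1.backward()
    gru1_opt.step()

    # Second forward pass with the freshly updated gru1, detached so that
    # the next backward pass stops at gru2's parameters.
    gru1_output, _ = gru1(vector)
    gru1_output = gru1_output.detach()

    gru2_output, _ = gru2(gru1_output)
    loss_gru2 = criterion(gru2_output, torch.randn_like(gru2_output))
    loss_gru2.backward()
    gru2_opt.step()
```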