This is my first time working with GANs and I'm facing an issue where the discriminator repeatedly overpowers the generator. I'm trying to reproduce the PA model from this article, and I'm referring to this slightly different implementation to help me out.
I've read a lot of papers on how GANs work, and I've also followed some tutorials to understand them better. On top of that, I've read several articles on how to overcome the main instabilities, but I haven't found a way to get past this behavior.
In my environment I'm using PyTorch with BCELoss(). Following the DCGAN PyTorch tutorial, I'm using the following training loop:
criterion = nn.BCELoss()
train_d = False

# Discriminator true
optim_d.zero_grad()
disc_train_real = target.to(device)
batch_size = disc_train_real.size(0)
label = torch.full((batch_size,), 1, device=device).cuda()
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label).cuda()
if lossT:
    loss_d_real *= 2
if loss_d_real.item() > 0.3:
    loss_d_real.backward()
    train_d = True
D_x = output_d.mean().item()

# Discriminator false
output_g = generator(image)
output_d = discriminator(output_g.detach()).view(-1)
label.fill_(0)
loss_d_fake = criterion(output_d, label).cuda()
D_G_z1 = output_d.mean().item()
if lossT:
    loss_d_fake *= 2
loss_d = loss_d_real + loss_d_fake
if loss_d_fake.item() > 0.3:
    loss_d_fake.backward()
    train_d = True
if train_d:
    optim_d.step()

# Generator
label.fill_(1)
output_d = discriminator(output_g).view(-1)
loss_g = criterion(output_d, label).cuda()
D_G_z2 = output_d.mean().item()
if lossT:
    loss_g *= 2
loss_g.backward()
optim_g.step()
And, after a period of settling, everything seemed to be working fine:
Epoch 1/5 - Step: 1900/9338 Loss G: 3.057388 Loss D: 0.214545 D(x): 0.940985 D(G(z)): 0.114064 / 0.114064
Time for the last step: 51.55 s Epoch ETA: 01:04:13
Epoch 1/5 - Step: 2000/9338 Loss G: 2.984724 Loss D: 0.222931 D(x): 0.879338 D(G(z)): 0.159163 / 0.159163
Time for the last step: 52.68 s Epoch ETA: 01:03:24
Epoch 1/5 - Step: 2100/9338 Loss G: 2.824713 Loss D: 0.241953 D(x): 0.905837 D(G(z)): 0.110231 / 0.110231
Time for the last step: 50.91 s Epoch ETA: 01:02:29
Epoch 1/5 - Step: 2200/9338 Loss G: 2.807455 Loss D: 0.252808 D(x): 0.908131 D(G(z)): 0.218515 / 0.218515
Time for the last step: 51.72 s Epoch ETA: 01:01:37
Epoch 1/5 - Step: 2300/9338 Loss G: 2.470529 Loss D: 0.569696 D(x): 0.620966 D(G(z)): 0.512615 / 0.350175
Time for the last step: 51.96 s Epoch ETA: 01:00:46
Epoch 1/5 - Step: 2400/9338 Loss G: 2.148863 Loss D: 1.071563 D(x): 0.809529 D(G(z)): 0.114487 / 0.114487
Time for the last step: 51.59 s Epoch ETA: 00:59:53
Epoch 1/5 - Step: 2500/9338 Loss G: 2.016863 Loss D: 0.904711 D(x): 0.621433 D(G(z)): 0.440721 / 0.435932
Time for the last step: 52.03 s Epoch ETA: 00:59:02
Epoch 1/5 - Step: 2600/9338 Loss G: 2.495639 Loss D: 0.949308 D(x): 0.671085 D(G(z)): 0.557924 / 0.420826
Time for the last step: 52.66 s Epoch ETA: 00:58:12
Epoch 1/5 - Step: 2700/9338 Loss G: 2.519842 Loss D: 0.798667 D(x): 0.775738 D(G(z)): 0.246357 / 0.265839
Time for the last step: 51.20 s Epoch ETA: 00:57:19
Epoch 1/5 - Step: 2800/9338 Loss G: 2.545630 Loss D: 0.756449 D(x): 0.895455 D(G(z)): 0.403628 / 0.301851
Time for the last step: 51.88 s Epoch ETA: 00:56:27
Epoch 1/5 - Step: 2900/9338 Loss G: 2.458109 Loss D: 0.653513 D(x): 0.820105 D(G(z)): 0.379199 / 0.103250
Time for the last step: 53.50 s Epoch ETA: 00:55:39
Epoch 1/5 - Step: 3000/9338 Loss G: 2.030103 Loss D: 0.948208 D(x): 0.445385 D(G(z)): 0.303225 / 0.263652
Time for the last step: 51.57 s Epoch ETA: 00:54:47
Epoch 1/5 - Step: 3100/9338 Loss G: 1.721604 Loss D: 0.949721 D(x): 0.365646 D(G(z)): 0.090072 / 0.232912
Time for the last step: 52.19 s Epoch ETA: 00:53:55
Epoch 1/5 - Step: 3200/9338 Loss G: 1.438854 Loss D: 1.142182 D(x): 0.768163 D(G(z)): 0.321164 / 0.237878
Time for the last step: 50.79 s Epoch ETA: 00:53:01
Epoch 1/5 - Step: 3300/9338 Loss G: 1.924418 Loss D: 0.923860 D(x): 0.729981 D(G(z)): 0.354812 / 0.318090
Time for the last step: 52.59 s Epoch ETA: 00:52:11
That is, the gradients on the generator were higher and started to decrease after a while, while in the meantime the gradients on the discriminator were rising. As for the losses, the generator's was going down while the discriminator's was going up. Compared to the tutorial, I figured this was acceptable.
Here's my first question: I noticed that in the tutorial, (usually) when D_G_z1 rises, D_G_z2 decreases (and vice versa), whereas in my example this happens much less often. Is it just a coincidence, or am I doing something wrong?
Given that, I let the training procedure keep going, but now I'm noticing this:
Epoch 3/5 - Step: 1100/9338 Loss G: 4.071329 Loss D: 0.031608 D(x): 0.999969 D(G(z)): 0.024329 / 0.024329
Time for the last step: 51.41 s Epoch ETA: 01:11:24
Epoch 3/5 - Step: 1200/9338 Loss G: 3.883331 Loss D: 0.036354 D(x): 0.999993 D(G(z)): 0.043874 / 0.043874
Time for the last step: 51.63 s Epoch ETA: 01:10:29
Epoch 3/5 - Step: 1300/9338 Loss G: 3.468963 Loss D: 0.054542 D(x): 0.999972 D(G(z)): 0.050145 / 0.050145
Time for the last step: 52.47 s Epoch ETA: 01:09:40
Epoch 3/5 - Step: 1400/9338 Loss G: 3.504971 Loss D: 0.053683 D(x): 0.999972 D(G(z)): 0.052180 / 0.052180
Time for the last step: 50.75 s Epoch ETA: 01:08:41
Epoch 3/5 - Step: 1500/9338 Loss G: 3.437765 Loss D: 0.056286 D(x): 0.999941 D(G(z)): 0.058839 / 0.058839
Time for the last step: 52.20 s Epoch ETA: 01:07:50
Epoch 3/5 - Step: 1600/9338 Loss G: 3.369209 Loss D: 0.062133 D(x): 0.955688 D(G(z)): 0.058773 / 0.058773
Time for the last step: 51.05 s Epoch ETA: 01:06:54
Epoch 3/5 - Step: 1700/9338 Loss G: 3.290109 Loss D: 0.065704 D(x): 0.999975 D(G(z)): 0.056583 / 0.056583
Time for the last step: 51.27 s Epoch ETA: 01:06:00
Epoch 3/5 - Step: 1800/9338 Loss G: 3.286248 Loss D: 0.067969 D(x): 0.993238 D(G(z)): 0.063815 / 0.063815
Time for the last step: 52.28 s Epoch ETA: 01:05:09
Epoch 3/5 - Step: 1900/9338 Loss G: 3.263996 Loss D: 0.065335 D(x): 0.980270 D(G(z)): 0.037717 / 0.037717
Time for the last step: 51.59 s Epoch ETA: 01:04:16
Epoch 3/5 - Step: 2000/9338 Loss G: 3.293503 Loss D: 0.065291 D(x): 0.999873 D(G(z)): 0.070188 / 0.070188
Time for the last step: 51.85 s Epoch ETA: 01:03:25
Epoch 3/5 - Step: 2100/9338 Loss G: 3.184164 Loss D: 0.070931 D(x): 0.999971 D(G(z)): 0.059657 / 0.059657
Time for the last step: 52.14 s Epoch ETA: 01:02:34
Epoch 3/5 - Step: 2200/9338 Loss G: 3.116310 Loss D: 0.080597 D(x): 0.999850 D(G(z)): 0.074931 / 0.074931
Time for the last step: 51.85 s Epoch ETA: 01:01:42
Epoch 3/5 - Step: 2300/9338 Loss G: 3.142180 Loss D: 0.073999 D(x): 0.995546 D(G(z)): 0.054752 / 0.054752
Time for the last step: 51.76 s Epoch ETA: 01:00:50
Epoch 3/5 - Step: 2400/9338 Loss G: 3.185711 Loss D: 0.072601 D(x): 0.999992 D(G(z)): 0.076053 / 0.076053
Time for the last step: 50.53 s Epoch ETA: 00:59:54
Epoch 3/5 - Step: 2500/9338 Loss G: 3.027437 Loss D: 0.083906 D(x): 0.997390 D(G(z)): 0.082501 / 0.082501
Time for the last step: 52.06 s Epoch ETA: 00:59:03
Epoch 3/5 - Step: 2600/9338 Loss G: 3.052374 Loss D: 0.085030 D(x): 0.999924 D(G(z)): 0.073295 / 0.073295
Time for the last step: 52.37 s Epoch ETA: 00:58:12
Not only has D(x) increased again and pinned itself at almost 1, but D_G_z1 and D_G_z2 now always show the same value. Moreover, looking at the losses it seems clear that the discriminator has overpowered the generator. This behavior went on for the rest of the epoch and for the whole next one, until the end of training.
Hence my second question: is this normal? If not, what am I doing wrong in the procedure? How can I achieve more stable training?
EDIT: I tried training the network using the suggested MSELoss(), and this is the output:
Epoch 1/1 - Step: 100/9338 Loss G: 0.800785 Loss D: 0.404525 D(x): 0.844653 D(G(z)): 0.030439 / 0.016316
Time for the last step: 55.22 s Epoch ETA: 01:25:01
Epoch 1/1 - Step: 200/9338 Loss G: 1.196659 Loss D: 0.014051 D(x): 0.999970 D(G(z)): 0.006543 / 0.006500
Time for the last step: 51.41 s Epoch ETA: 01:21:11
Epoch 1/1 - Step: 300/9338 Loss G: 1.197319 Loss D: 0.000806 D(x): 0.999431 D(G(z)): 0.004821 / 0.004724
Time for the last step: 51.79 s Epoch ETA: 01:19:32
Epoch 1/1 - Step: 400/9338 Loss G: 1.198960 Loss D: 0.000720 D(x): 0.999612 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.47 s Epoch ETA: 01:18:09
Epoch 1/1 - Step: 500/9338 Loss G: 1.212810 Loss D: 0.000021 D(x): 0.999938 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.18 s Epoch ETA: 01:17:11
Epoch 1/1 - Step: 600/9338 Loss G: 1.216168 Loss D: 0.000000 D(x): 0.999945 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.24 s Epoch ETA: 01:16:02
Epoch 1/1 - Step: 700/9338 Loss G: 1.212301 Loss D: 0.000000 D(x): 0.999970 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.61 s Epoch ETA: 01:15:02
Epoch 1/1 - Step: 800/9338 Loss G: 1.214397 Loss D: 0.000005 D(x): 0.999973 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.58 s Epoch ETA: 01:14:04
Epoch 1/1 - Step: 900/9338 Loss G: 1.212016 Loss D: 0.000003 D(x): 0.999932 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.20 s Epoch ETA: 01:13:13
Epoch 1/1 - Step: 1000/9338 Loss G: 1.215162 Loss D: 0.000000 D(x): 0.999988 D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.28 s Epoch ETA: 01:12:23
Epoch 1/1 - Step: 1100/9338 Loss G: 1.216291 Loss D: 0.000000 D(x): 0.999983 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.78 s Epoch ETA: 01:11:28
Epoch 1/1 - Step: 1200/9338 Loss G: 1.215526 Loss D: 0.000000 D(x): 0.999978 D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.88 s Epoch ETA: 01:10:35
As can be seen, the situation got even worse. Moreover, re-reading the EnhanceNet paper, Section 4.2.4 (Adversarial Training) states that the adversarial loss function used is BCELoss(), which is what I would expect, in order to work around the vanishing-gradient problem that arises with MSELoss().
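To make that vanishing-gradient point concrete, here is a minimal sketch of my own (not code from the paper or the tutorial): it compares the gradient each loss sends back through a single pre-sigmoid score of a fake sample when the discriminator confidently rejects it.

import torch

# One fake sample's pre-sigmoid score; the discriminator confidently
# rejects it, so D(G(z)) = sigmoid(logit) is close to 0.
logit = torch.tensor([-6.0], requires_grad=True)
p = torch.sigmoid(logit)      # D(G(z)) ~ 0.0025
ones = torch.ones_like(p)     # the generator wants D(G(z)) = 1

bce = torch.nn.functional.binary_cross_entropy(p, ones)
grad_bce = torch.autograd.grad(bce, logit, retain_graph=True)[0]

mse = torch.nn.functional.mse_loss(p, ones)
grad_mse = torch.autograd.grad(mse, logit)[0]

print(grad_bce.item())  # ~ -1.0   : BCE still pushes the generator hard
print(grad_mse.item())  # ~ -0.005 : the MSE gradient has all but vanished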
Answer:
Interpreting GAN loss is a bit of a black art, because the actual loss values tell you very little on their own.
Question 1: The frequency of the swings between discriminator/generator dominance will vary with a few factors, chiefly (in my experience) the learning rates and batch sizes, which affect the propagated loss. The particular loss metric used will also impact the variance in how the D and G networks train. The EnhanceNet paper (as a baseline) and the tutorial use a Mean Squared Error loss too - you're using a Binary Cross Entropy loss, which will change the rate at which the networks converge. I'm no expert, so here's a really good link to Rohan Varma's article that explains the differences between loss functions. I'd be curious whether your network behaves differently when you change the loss function - try it and update the question?
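(On the learning-rate point, for what it's worth, the DCGAN tutorial's optimizer settings - Adam with lr=0.0002 and beta1=0.5 - are the usual starting point when tuning how hard D and G push against each other. A sketch, assuming your discriminator/generator networks and optimizer names:)

import torch.optim as optim

# DCGAN-tutorial-style optimizers (Adam, lr=0.0002, beta1=0.5): a common
# baseline when tuning the learning rates behind the D/G swings.
# `discriminator` and `generator` are assumed to be the question's networks.
optim_d = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optim_g = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))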
Question 2: Over time the D and G losses should settle to a value, however it's fairly difficult to tell whether they've converged due to strong performance or due to something like mode collapse / vanishing gradients (see Jonathan Hui's explanation of problems in training GANs). The best way I've found is to actually inspect a cross-section of the generated images and either visually inspect the output or use some kind of perceptual metric (SSIM, PSNR, PIQ, etc.) across the generated image set.
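As a concrete example of the metric route, a hand-rolled PSNR over a generated batch is only a few lines (a sketch; fake_batch and real_batch are hypothetical image tensors scaled to [0, 1]):

import torch

def psnr(x: torch.Tensor, y: torch.Tensor, max_val: float = 1.0) -> float:
    # Peak signal-to-noise ratio between two image batches in [0, max_val].
    mse = torch.mean((x - y) ** 2)
    return (10 * torch.log10(max_val ** 2 / mse)).item()

# Hypothetical usage on a batch of generated vs. target images:
# fake_batch = generator(image).detach()
# print(f"PSNR: {psnr(fake_batch, real_batch):.2f} dB")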
You might also find some useful leads towards an answer in the following:
This thread has a couple of reasonably good pointers on interpreting GAN losses.
Ian Goodfellow's NIPS 2016 tutorial also has some solid ideas on how to balance D and G training, as sketched below.
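One trick from that tutorial, one-sided label smoothing, drops straight into your loop: give the discriminator a soft target (e.g. 0.9) for real samples while leaving the fake targets at 0, so D is never rewarded for becoming absolutely certain. A sketch against the variable names in your code:

# One-sided label smoothing (from Goodfellow's NIPS 2016 tutorial),
# sketched with the variable names from the question's training loop.
real_label_smooth = 0.9  # soft target for reals; fake targets stay at 0.0

label = torch.full((batch_size,), real_label_smooth, device=device)
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label)

label.fill_(0.0)  # NOT smoothed for fakes - hence "one-sided"
output_d = discriminator(output_g.detach()).view(-1)
loss_d_fake = criterion(output_d, label)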