How can I balance the generator and the discriminator performances in a GAN?

This is my first time working with GANs, and I am facing an issue where the discriminator repeatedly outperforms the generator. I am trying to reproduce the PA model from this article, and I am looking at this slightly different implementation to help me out.

I have read a lot of papers on how GANs work, and I have also followed some tutorials to understand them better. In addition, I have read articles on how to overcome the main instabilities, but I cannot find a way to overcome this behavior.

In my environment, I am using PyTorch's BCELoss(). Following the DCGAN PyTorch tutorial, I am using the following training loop:

criterion = nn.BCELoss()
train_d = False

# Discriminator true
optim_d.zero_grad()
disc_train_real = target.to(device)
batch_size = disc_train_real.size(0)
label = torch.full((batch_size,), 1, device=device).cuda()
output_d = discriminator(disc_train_real).view(-1)
loss_d_real = criterion(output_d, label).cuda()
if lossT:
    loss_d_real *= 2
if loss_d_real.item() > 0.3:
    loss_d_real.backward()
    train_d = True
D_x = output_d.mean().item()

# Discriminator false
output_g = generator(image)
output_d = discriminator(output_g.detach()).view(-1)
label.fill_(0)
loss_d_fake = criterion(output_d, label).cuda()
D_G_z1 = output_d.mean().item()
if lossT:
    loss_d_fake *= 2
loss_d = loss_d_real + loss_d_fake
if loss_d_fake.item() > 0.3:
    loss_d_fake.backward()
    train_d = True
if train_d:
    optim_d.step()

# Generator
label.fill_(1)
output_d = discriminator(output_g).view(-1)
loss_g = criterion(output_d, label).cuda()
D_G_z2 = output_d.mean().item()
if lossT:
    loss_g *= 2
loss_g.backward()
optim_g.step()

And, after an initial settling period, everything seemed to work fine:

Epoch 1/5 - Step: 1900/9338  Loss G: 3.057388  Loss D: 0.214545  D(x): 0.940985  D(G(z)): 0.114064 / 0.114064
Time for the last step: 51.55 s    Epoch ETA: 01:04:13
Epoch 1/5 - Step: 2000/9338  Loss G: 2.984724  Loss D: 0.222931  D(x): 0.879338  D(G(z)): 0.159163 / 0.159163
Time for the last step: 52.68 s    Epoch ETA: 01:03:24
Epoch 1/5 - Step: 2100/9338  Loss G: 2.824713  Loss D: 0.241953  D(x): 0.905837  D(G(z)): 0.110231 / 0.110231
Time for the last step: 50.91 s    Epoch ETA: 01:02:29
Epoch 1/5 - Step: 2200/9338  Loss G: 2.807455  Loss D: 0.252808  D(x): 0.908131  D(G(z)): 0.218515 / 0.218515
Time for the last step: 51.72 s    Epoch ETA: 01:01:37
Epoch 1/5 - Step: 2300/9338  Loss G: 2.470529  Loss D: 0.569696  D(x): 0.620966  D(G(z)): 0.512615 / 0.350175
Time for the last step: 51.96 s    Epoch ETA: 01:00:46
Epoch 1/5 - Step: 2400/9338  Loss G: 2.148863  Loss D: 1.071563  D(x): 0.809529  D(G(z)): 0.114487 / 0.114487
Time for the last step: 51.59 s    Epoch ETA: 00:59:53
Epoch 1/5 - Step: 2500/9338  Loss G: 2.016863  Loss D: 0.904711  D(x): 0.621433  D(G(z)): 0.440721 / 0.435932
Time for the last step: 52.03 s    Epoch ETA: 00:59:02
Epoch 1/5 - Step: 2600/9338  Loss G: 2.495639  Loss D: 0.949308  D(x): 0.671085  D(G(z)): 0.557924 / 0.420826
Time for the last step: 52.66 s    Epoch ETA: 00:58:12
Epoch 1/5 - Step: 2700/9338  Loss G: 2.519842  Loss D: 0.798667  D(x): 0.775738  D(G(z)): 0.246357 / 0.265839
Time for the last step: 51.20 s    Epoch ETA: 00:57:19
Epoch 1/5 - Step: 2800/9338  Loss G: 2.545630  Loss D: 0.756449  D(x): 0.895455  D(G(z)): 0.403628 / 0.301851
Time for the last step: 51.88 s    Epoch ETA: 00:56:27
Epoch 1/5 - Step: 2900/9338  Loss G: 2.458109  Loss D: 0.653513  D(x): 0.820105  D(G(z)): 0.379199 / 0.103250
Time for the last step: 53.50 s    Epoch ETA: 00:55:39
Epoch 1/5 - Step: 3000/9338  Loss G: 2.030103  Loss D: 0.948208  D(x): 0.445385  D(G(z)): 0.303225 / 0.263652
Time for the last step: 51.57 s    Epoch ETA: 00:54:47
Epoch 1/5 - Step: 3100/9338  Loss G: 1.721604  Loss D: 0.949721  D(x): 0.365646  D(G(z)): 0.090072 / 0.232912
Time for the last step: 52.19 s    Epoch ETA: 00:53:55
Epoch 1/5 - Step: 3200/9338  Loss G: 1.438854  Loss D: 1.142182  D(x): 0.768163  D(G(z)): 0.321164 / 0.237878
Time for the last step: 50.79 s    Epoch ETA: 00:53:01
Epoch 1/5 - Step: 3300/9338  Loss G: 1.924418  Loss D: 0.923860  D(x): 0.729981  D(G(z)): 0.354812 / 0.318090
Time for the last step: 52.59 s    Epoch ETA: 00:52:11

That is, the gradients on the generator were higher and started to decrease after a while, while at the same time the gradients on the discriminator rose. As for the losses, the generator's went down while the discriminator's went up. Compared to the tutorial, I thought this was acceptable.

Here is my first question: I have noticed that in the tutorial, (usually) as D_G_z1 rises, D_G_z2 decreases (and vice versa), whereas in my example this happens far less often. Is it just a coincidence, or am I doing something wrong?

Given that, I let the training procedure go on, but now I am noticing this:

Epoch 3/5 - Step: 1100/9338  Loss G: 4.071329  Loss D: 0.031608  D(x): 0.999969  D(G(z)): 0.024329 / 0.024329
Time for the last step: 51.41 s    Epoch ETA: 01:11:24
Epoch 3/5 - Step: 1200/9338  Loss G: 3.883331  Loss D: 0.036354  D(x): 0.999993  D(G(z)): 0.043874 / 0.043874
Time for the last step: 51.63 s    Epoch ETA: 01:10:29
Epoch 3/5 - Step: 1300/9338  Loss G: 3.468963  Loss D: 0.054542  D(x): 0.999972  D(G(z)): 0.050145 / 0.050145
Time for the last step: 52.47 s    Epoch ETA: 01:09:40
Epoch 3/5 - Step: 1400/9338  Loss G: 3.504971  Loss D: 0.053683  D(x): 0.999972  D(G(z)): 0.052180 / 0.052180
Time for the last step: 50.75 s    Epoch ETA: 01:08:41
Epoch 3/5 - Step: 1500/9338  Loss G: 3.437765  Loss D: 0.056286  D(x): 0.999941  D(G(z)): 0.058839 / 0.058839
Time for the last step: 52.20 s    Epoch ETA: 01:07:50
Epoch 3/5 - Step: 1600/9338  Loss G: 3.369209  Loss D: 0.062133  D(x): 0.955688  D(G(z)): 0.058773 / 0.058773
Time for the last step: 51.05 s    Epoch ETA: 01:06:54
Epoch 3/5 - Step: 1700/9338  Loss G: 3.290109  Loss D: 0.065704  D(x): 0.999975  D(G(z)): 0.056583 / 0.056583
Time for the last step: 51.27 s    Epoch ETA: 01:06:00
Epoch 3/5 - Step: 1800/9338  Loss G: 3.286248  Loss D: 0.067969  D(x): 0.993238  D(G(z)): 0.063815 / 0.063815
Time for the last step: 52.28 s    Epoch ETA: 01:05:09
Epoch 3/5 - Step: 1900/9338  Loss G: 3.263996  Loss D: 0.065335  D(x): 0.980270  D(G(z)): 0.037717 / 0.037717
Time for the last step: 51.59 s    Epoch ETA: 01:04:16
Epoch 3/5 - Step: 2000/9338  Loss G: 3.293503  Loss D: 0.065291  D(x): 0.999873  D(G(z)): 0.070188 / 0.070188
Time for the last step: 51.85 s    Epoch ETA: 01:03:25
Epoch 3/5 - Step: 2100/9338  Loss G: 3.184164  Loss D: 0.070931  D(x): 0.999971  D(G(z)): 0.059657 / 0.059657
Time for the last step: 52.14 s    Epoch ETA: 01:02:34
Epoch 3/5 - Step: 2200/9338  Loss G: 3.116310  Loss D: 0.080597  D(x): 0.999850  D(G(z)): 0.074931 / 0.074931
Time for the last step: 51.85 s    Epoch ETA: 01:01:42
Epoch 3/5 - Step: 2300/9338  Loss G: 3.142180  Loss D: 0.073999  D(x): 0.995546  D(G(z)): 0.054752 / 0.054752
Time for the last step: 51.76 s    Epoch ETA: 01:00:50
Epoch 3/5 - Step: 2400/9338  Loss G: 3.185711  Loss D: 0.072601  D(x): 0.999992  D(G(z)): 0.076053 / 0.076053
Time for the last step: 50.53 s    Epoch ETA: 00:59:54
Epoch 3/5 - Step: 2500/9338  Loss G: 3.027437  Loss D: 0.083906  D(x): 0.997390  D(G(z)): 0.082501 / 0.082501
Time for the last step: 52.06 s    Epoch ETA: 00:59:03
Epoch 3/5 - Step: 2600/9338  Loss G: 3.052374  Loss D: 0.085030  D(x): 0.999924  D(G(z)): 0.073295 / 0.073295
Time for the last step: 52.37 s    Epoch ETA: 00:58:12

Not only has D(x) increased again and got stuck at almost one, but also D_G_z1 and D_G_z2 now always show the same value. Moreover, looking at the losses, it seems clear that the discriminator has outperformed the generator. This behavior went on for the rest of the epoch and for the whole next one, until the end of training.

Hence my second question: is this normal? If not, what am I doing wrong within the procedure? How can I achieve a more stable training?

EDIT: I tried to train the network using the suggested MSELoss(), and this is the output:

Epoch 1/1 - Step: 100/9338  Loss G: 0.800785  Loss D: 0.404525  D(x): 0.844653  D(G(z)): 0.030439 / 0.016316
Time for the last step: 55.22 s    Epoch ETA: 01:25:01
Epoch 1/1 - Step: 200/9338  Loss G: 1.196659  Loss D: 0.014051  D(x): 0.999970  D(G(z)): 0.006543 / 0.006500
Time for the last step: 51.41 s    Epoch ETA: 01:21:11
Epoch 1/1 - Step: 300/9338  Loss G: 1.197319  Loss D: 0.000806  D(x): 0.999431  D(G(z)): 0.004821 / 0.004724
Time for the last step: 51.79 s    Epoch ETA: 01:19:32
Epoch 1/1 - Step: 400/9338  Loss G: 1.198960  Loss D: 0.000720  D(x): 0.999612  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.47 s    Epoch ETA: 01:18:09
Epoch 1/1 - Step: 500/9338  Loss G: 1.212810  Loss D: 0.000021  D(x): 0.999938  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.18 s    Epoch ETA: 01:17:11
Epoch 1/1 - Step: 600/9338  Loss G: 1.216168  Loss D: 0.000000  D(x): 0.999945  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.24 s    Epoch ETA: 01:16:02
Epoch 1/1 - Step: 700/9338  Loss G: 1.212301  Loss D: 0.000000  D(x): 0.999970  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.61 s    Epoch ETA: 01:15:02
Epoch 1/1 - Step: 800/9338  Loss G: 1.214397  Loss D: 0.000005  D(x): 0.999973  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.58 s    Epoch ETA: 01:14:04
Epoch 1/1 - Step: 900/9338  Loss G: 1.212016  Loss D: 0.000003  D(x): 0.999932  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.20 s    Epoch ETA: 01:13:13
Epoch 1/1 - Step: 1000/9338  Loss G: 1.215162  Loss D: 0.000000  D(x): 0.999988  D(G(z)): 0.000000 / 0.000000
Time for the last step: 52.28 s    Epoch ETA: 01:12:23
Epoch 1/1 - Step: 1100/9338  Loss G: 1.216291  Loss D: 0.000000  D(x): 0.999983  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.78 s    Epoch ETA: 01:11:28
Epoch 1/1 - Step: 1200/9338  Loss G: 1.215526  Loss D: 0.000000  D(x): 0.999978  D(G(z)): 0.000000 / 0.000000
Time for the last step: 51.88 s    Epoch ETA: 01:10:35

As can be seen, the situation gets even worse. Moreover, rereading the EnhanceNet paper, Section 4.2.4 (Adversarial Training) states that the adversarial loss function used is BCELoss(), as I expected, precisely to overcome the vanishing-gradient problem that arises with MSELoss().


Answer:

Interpreting GAN losses is a bit of a black art, because the actual loss values tell you relatively little on their own.

Question 1: how often training swings between discriminator and generator dominance will vary based on a few factors (in my experience): primarily the learning rate and batch size, which affect the propagated loss. The particular loss metric used will also affect the variance in how the D and G networks train. The EnhanceNet paper (used as the baseline) and the tutorial use a mean squared error loss, while you are using a binary cross-entropy loss, which will change the rate at which the networks converge. I'm no expert, so here is a really good link to Rohan Varma's article explaining the differences between loss functions. I would be curious to see whether your network behaves differently when you change the loss function - try it and update the question?
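If you do try that experiment, swapping the criterion is essentially a one-line change in the loop posted above. The following is only a minimal, hypothetical sketch, not a drop-in replacement for the question's code: discriminator, real_batch and device are placeholder names, and it assumes the discriminator ends in a sigmoid so its output lies in (0, 1), which both criteria accept.

import torch
import torch.nn as nn

criterion_bce = nn.BCELoss()   # what the loop above currently uses
criterion_mse = nn.MSELoss()   # least-squares (LSGAN-style) alternative to compare against

def discriminator_real_loss(discriminator, real_batch, criterion, device="cuda"):
    # Loss on a batch of real images; real targets are 1, as in the question's loop.
    output_d = discriminator(real_batch.to(device)).view(-1)
    label = torch.full_like(output_d, 1.0)
    return criterion(output_d, label)

The same swap applies to the fake-batch and generator losses; only the criterion object changes, the 1/0 targets stay as they are.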

Question 2: over time, both the D and G losses should settle to a value; however, it is quite hard to tell whether they have converged because of strong performance, or because of something like mode collapse / vanishing gradients (see Jonathan Hui's explanation of problems in training GANs). The best way I have found is to actually inspect a cross-section of the generated images and either evaluate the output visually or use some kind of perceptual metric (SSIM, PSNR, PIQ, etc.) across the generated image set.
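For the metric route, here is a rough sketch of what tracking PSNR/SSIM over a batch could look like. It assumes scikit-image (version 0.19 or later, for channel_axis) is available and that images are (N, C, H, W) tensors scaled to [0, 1]; generated_batch and target_batch are hypothetical names, not variables from the question's loop. PIQ or another package would work just as well.

import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def batch_psnr_ssim(generated_batch, target_batch):
    # Average PSNR/SSIM over a batch of (N, C, H, W) tensors with values in [0, 1].
    gen = generated_batch.detach().cpu().numpy().transpose(0, 2, 3, 1)  # to (N, H, W, C)
    ref = target_batch.detach().cpu().numpy().transpose(0, 2, 3, 1)
    psnr_vals, ssim_vals = [], []
    for g, r in zip(gen, ref):
        psnr_vals.append(peak_signal_noise_ratio(r, g, data_range=1.0))
        ssim_vals.append(structural_similarity(r, g, channel_axis=-1, data_range=1.0))
    return float(np.mean(psnr_vals)), float(np.mean(ssim_vals))

Logging these two numbers every few hundred steps, alongside the losses, gives a somewhat more reliable signal of whether the generator is actually improving than the adversarial losses alone.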

You may also find some useful leads in the following:

This thread has some reasonably good guidance on interpreting GAN losses.

Ian Goodfellow's NIPS 2016 tutorial also has some solid ideas on how to balance the training of D and G.
