我想在使用Chainer5.0.0的VAE中,使用mean_squared_error替代F.bernoulli_nll作为重构损失函数。
我是Chainer5.0.0的用户。我已经实现了VAE(变分自编码器)。我参考了以下日文文章。
- https://qiita.com/kenmatsu4/items/b029d697e9995d93aa24
- https://qiita.com/kenchin110100/items/7ceb5b8e8b21c551d69a
- https://github.com/maguro27/VAE-CIFAR10_chainer
class VAE(chainer.Chain): def __init__(self, n_in, n_latent, n_h, act_func=F.tanh): super(VAE, self).__init__() self.act_func = act_func with self.init_scope(): # encoder self.le1 = L.Linear(n_in, n_h) self.le2 = L.Linear(n_h, n_h) self.le3_mu = L.Linear(n_h, n_latent) self.le3_ln_var = L.Linear(n_h, n_latent) # decoder self.ld1 = L.Linear(n_latent, n_h) self.ld2 = L.Linear(n_h, n_h) self.ld3 = L.Linear(n_h, n_in) def __call__(self, x, sigmoid=True): return self.decode(self.encode(x)[0], sigmoid) def encode(self, x): h1 = self.act_func(self.le1(x)) h2 = self.act_func(self.le2(h1)) mu = self.le3_mu(h2) ln_var = self.le3_ln_var(h2) return mu, ln_var def decode(self, z, sigmoid=True): h1 = self.act_func(self.ld1(z)) h2 = self.act_func(self.ld2(h1)) h3 = self.ld3(h2) if sigmoid: return F.sigmoid(h3) else: return h3 def get_loss_func(self, C=1.0, k=1): def lf(x): mu, ln_var = self.encode(x) batchsize = len(mu.data) # reconstruction error rec_loss = 0 for l in six.moves.range(k): z = F.gaussian(mu, ln_var) z.name = "z" rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize) self.rec_loss = rec_loss self.rec_loss.name = "reconstruction error" self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize self..name = "latent loss" self.loss = self.rec_loss + self.latent_loss self.loss.name = "loss" return self.loss return lf
我使用了这段代码,我的VAE已经在MNIST和Fashion-MNIST数据集上进行了训练。训练后,我检查了我的VAE输出与输入图像相似的图像。
rec_loss是重构损失,意味着解码图像与输入图像的差异有多大。我认为我们可以使用mean_squared_error替代F.bernoulli_nll。
所以我修改了代码,如下所示。
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
但是修改代码后,训练结果表现得很奇怪。输出图像相同,这意味着输出图像不依赖于输入图像。
问题出在哪里?
我在日文版Stack Overflow上问了这个问题(https://ja.stackoverflow.com/questions/55477/chainer%E3%81%A7vae%E3%82%92%E4%BD%9C%E3%82%8B%E3%81%A8%E3%81%8D%E3%81%ABloss%E9%96%A2%E6%95%B0%E3%82%92bernoulli-nll%E3%81%A7%E3%81%AF%E3%81%AA%E3%81%8Fmse%E3%82%92%E4%BD%BF%E3%81%86%E3%81%A8%E5%AD%A6%E7%BF%92%E3%81%8C%E9%80%B2%E3%81%BE%E3%81%AA%E3%81%84)。但没有人回应,所以我在这里提交这个问题。
解决方案?
当我将
rec_loss += F.mean_squared_error(x, self.decode(z)) / k
替换为
rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))
时,问题得到了解决。
但为什么呢?
回答:
除了后者代码使用F.mean(F.sum....
仅沿小批量轴取平均(因为它已经沿输入数据维度求和,MNIST展平的情况下为784),而前者沿小批量轴和输入数据维度取平均之外,它们应该是相同的。这意味着,对于展平的MNIST,后者的损失是前者的784倍?我假设k
是1
。