### 当我在Chainer中将重构损失函数从F.bernoulli_nll改为F.mean_squared_error时,VAE无法学习

我想在使用Chainer5.0.0的VAE中,使用mean_squared_error替代F.bernoulli_nll作为重构损失函数。

我是Chainer5.0.0的用户。我已经实现了VAE(变分自编码器)。我参考了以下日文文章。

class VAE(chainer.Chain):    def __init__(self, n_in, n_latent, n_h, act_func=F.tanh):        super(VAE, self).__init__()        self.act_func = act_func        with self.init_scope():            # encoder            self.le1        = L.Linear(n_in, n_h)            self.le2        = L.Linear(n_h,  n_h)            self.le3_mu     = L.Linear(n_h,  n_latent)            self.le3_ln_var = L.Linear(n_h,  n_latent)            # decoder            self.ld1 = L.Linear(n_latent, n_h)            self.ld2 = L.Linear(n_h,      n_h)            self.ld3 = L.Linear(n_h,      n_in)    def __call__(self, x, sigmoid=True):        return self.decode(self.encode(x)[0], sigmoid)    def encode(self, x):        h1 = self.act_func(self.le1(x))        h2 = self.act_func(self.le2(h1))        mu = self.le3_mu(h2)        ln_var = self.le3_ln_var(h2)         return mu, ln_var    def decode(self, z, sigmoid=True):        h1 = self.act_func(self.ld1(z))        h2 = self.act_func(self.ld2(h1))        h3 = self.ld3(h2)        if sigmoid:            return F.sigmoid(h3)        else:            return h3    def get_loss_func(self, C=1.0, k=1):        def lf(x):            mu, ln_var = self.encode(x)            batchsize = len(mu.data)            # reconstruction error            rec_loss = 0            for l in six.moves.range(k):                z = F.gaussian(mu, ln_var)                z.name = "z"                rec_loss += F.bernoulli_nll(x, self.decode(z, sigmoid=False)) / (k * batchsize)            self.rec_loss = rec_loss            self.rec_loss.name = "reconstruction error"            self.latent_loss = C * gaussian_kl_divergence(mu, ln_var) / batchsize            self..name = "latent loss"            self.loss = self.rec_loss + self.latent_loss            self.loss.name = "loss"            return self.loss        return lf

我使用了这段代码,我的VAE已经在MNIST和Fashion-MNIST数据集上进行了训练。训练后,我检查了我的VAE输出与输入图像相似的图像。

rec_loss是重构损失,意味着解码图像与输入图像的差异有多大。我认为我们可以使用mean_squared_error替代F.bernoulli_nll。

所以我修改了代码,如下所示。

rec_loss += F.mean_squared_error(x, self.decode(z)) / k

但是修改代码后,训练结果表现得很奇怪。输出图像相同,这意味着输出图像不依赖于输入图像。

问题出在哪里?

我在日文版Stack Overflow上问了这个问题(https://ja.stackoverflow.com/questions/55477/chainer%E3%81%A7vae%E3%82%92%E4%BD%9C%E3%82%8B%E3%81%A8%E3%81%8D%E3%81%ABloss%E9%96%A2%E6%95%B0%E3%82%92bernoulli-nll%E3%81%A7%E3%81%AF%E3%81%AA%E3%81%8Fmse%E3%82%92%E4%BD%BF%E3%81%86%E3%81%A8%E5%AD%A6%E7%BF%92%E3%81%8C%E9%80%B2%E3%81%BE%E3%81%AA%E3%81%84)。但没有人回应,所以我在这里提交这个问题。

解决方案?

当我将

rec_loss += F.mean_squared_error(x, self.decode(z)) / k 

替换为

rec_loss += F.mean(F.sum((x - self.decode(z)) ** 2, axis=1))

时,问题得到了解决。

但为什么呢?


回答:

除了后者代码使用F.mean(F.sum....仅沿小批量轴取平均(因为它已经沿输入数据维度求和,MNIST展平的情况下为784),而前者沿小批量轴和输入数据维度取平均之外,它们应该是相同的。这意味着,对于展平的MNIST,后者的损失是前者的784倍?我假设k1

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注