I want to train a variational autoencoder (VAE) on a very large dataset, so I decided to start from a VAE written for Fashion MNIST and adopt a popular modification found on GitHub for loading filenames in batches. My Colab notebook is here, and here is a sample portion of the dataset.
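For context, a batch-loading generator of this kind is usually a tf.keras.utils.Sequence subclass that keeps only the filenames in memory and reads the images one batch at a time. The sketch below illustrates that pattern; it is not the exact code from the notebook, and ImageBatchGenerator and load_image are hypothetical names.

import math
import numpy as np
import tensorflow as tf

class ImageBatchGenerator(tf.keras.utils.Sequence):
    """Yields image batches from a list of file paths (illustrative only)."""

    def __init__(self, filenames, batch_size, load_image):
        self.filenames = filenames    # list of image file paths
        self.batch_size = batch_size
        self.load_image = load_image  # callable: path -> array of shape (128, 128, 3)

    def __len__(self):
        # number of batches per epoch
        return math.ceil(len(self.filenames) / self.batch_size)

    def __getitem__(self, idx):
        batch_paths = self.filenames[idx * self.batch_size:(idx + 1) * self.batch_size]
        x = np.stack([self.load_image(p) for p in batch_paths]).astype("float32")
        return x  # the VAE only needs the inputs; return (x, x) instead if a compiled loss expects targets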
However, the VAE class is written without the call function it should have according to the Keras documentation, and I get the error: NotImplementedError: When subclassing the `Model` class, you should implement a `call` method.
import tensorflow as tf
import tensorflow_probability as tfp  # assumption: ds below refers to tfp.distributions

ds = tfp.distributions


class VAE(tf.keras.Model):
    """a basic vae class for tensorflow

    Extends: tf.keras.Model
    """

    def __init__(self, **kwargs):
        super(VAE, self).__init__()
        self.__dict__.update(kwargs)
        self.enc = tf.keras.Sequential(self.enc)
        self.dec = tf.keras.Sequential(self.dec)

    def encode(self, x):
        mu, sigma = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
        return ds.MultivariateNormalDiag(loc=mu, scale_diag=sigma)

    def reparameterize(self, mean, logvar):
        eps = tf.random.normal(shape=mean.shape)
        return eps * tf.exp(logvar * 0.5) + mean

    def reconstruct(self, x):
        mu, _ = tf.split(self.enc(x), num_or_size_splits=2, axis=1)
        return self.decode(mu)

    def decode(self, z):
        return self.dec(z)

    def compute_loss(self, x):
        q_z = self.encode(x)
        z = q_z.sample()
        x_recon = self.decode(z)
        p_z = ds.MultivariateNormalDiag(
            loc=[0.] * z.shape[-1], scale_diag=[1.] * z.shape[-1]
        )
        kl_div = ds.kl_divergence(q_z, p_z)
        latent_loss = tf.reduce_mean(tf.maximum(kl_div, 0))
        recon_loss = tf.reduce_mean(tf.reduce_sum(tf.math.square(x - x_recon), axis=0))
        return recon_loss, latent_loss

    def compute_gradients(self, x):
        with tf.GradientTape() as tape:
            loss = self.compute_loss(x)
        return tape.gradient(loss, self.trainable_variables)

    @tf.function
    def train(self, train_x):
        gradients = self.compute_gradients(train_x)
        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
The encoder and decoder are defined and compiled separately, as follows:
N_Z = 8
filt_base = 32
DIMS = (128, 128, 3)

encoder = [
    tf.keras.layers.InputLayer(input_shape=DIMS),
    tf.keras.layers.Conv2D(
        filters=filt_base, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base*2, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base*2, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base*3, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base*3, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base*4, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),
    tf.keras.layers.Conv2D(
        filters=filt_base*4, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(units=N_Z*2),
]

decoder = [
    tf.keras.layers.Dense(units=8 * 8 * 128, activation="relu"),
    tf.keras.layers.Reshape(target_shape=(8, 8, 128)),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base*4, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base*4, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base*3, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base*3, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base*2, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base*2, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=filt_base, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),
    tf.keras.layers.Conv2DTranspose(
        filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"),
]

optimizer = tf.keras.optimizers.Adam(1e-3)

model = VAE(
    enc=encoder,
    dec=decoder,
    optimizer=optimizer,
)
model.compile(optimizer=optimizer)
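As a quick sanity check of the pieces above (my own sketch, not from the notebook), the model can be probed with a random batch before training. Note that the last Conv2DTranspose uses filters=1, so the reconstruction comes out with 1 channel while DIMS has 3; the squared error still broadcasts, but it is worth double-checking that this is intended.

dummy = tf.random.uniform((4,) + DIMS)                 # 4 fake RGB images in [0, 1]
print(model.enc(dummy).shape)                          # (4, 16): 8 values for mu, 8 for sigma
print(model.decode(tf.random.normal((4, N_Z))).shape)  # (4, 128, 128, 1)
recon_loss, latent_loss = model.compute_loss(dummy)    # runs the full forward pass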
I then tried to train the model with the fit_generator function:
num_epochs = 50

model.fit_generator(generator=my_training_batch_generator,
                    steps_per_epoch=(num_training_samples // batch_size),
                    epochs=num_epochs,
                    verbose=1,
                    validation_data=my_validation_batch_generator,
                    validation_steps=(num_validation_samples // batch_size),
                    use_multiprocessing=True,
                    workers=16,
                    max_queue_size=32)
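For reference, in TensorFlow 2.x fit_generator is deprecated and Model.fit accepts generator/Sequence objects directly, so once the call() issue below is resolved the same call can be written as:

num_epochs = 50

model.fit(my_training_batch_generator,
          steps_per_epoch=num_training_samples // batch_size,
          epochs=num_epochs,
          verbose=1,
          validation_data=my_validation_batch_generator,
          validation_steps=num_validation_samples // batch_size,
          use_multiprocessing=True,
          workers=16,
          max_queue_size=32)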
I am new to machine learning, so any help with this problem would be appreciated. I think the issue is in the def train line of the VAE class.
An optional request: if training can be made to work in a way that lets me see reconstructions after each epoch, that would be much appreciated. I already have a plot_reconstruction function for this purpose in the Colab notebook that just needs to be called.
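One common way to do that is a small Keras callback that runs at the end of every epoch and is passed to fit via callbacks=[...]. The sketch below assumes plot_reconstruction takes the model and a batch of images; adjust it to the function's actual signature.

class ReconstructionPlotter(tf.keras.callbacks.Callback):
    """Plots reconstructions of a fixed example batch after each epoch (illustrative)."""

    def __init__(self, vae, example_batch):
        super().__init__()
        self.vae = vae
        self.example_batch = example_batch  # a fixed batch of images to reconstruct

    def on_epoch_end(self, epoch, logs=None):
        plot_reconstruction(self.vae, self.example_batch)  # assumed signature

# e.g. model.fit(..., callbacks=[ReconstructionPlotter(model, example_batch)])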
Answer:
APaul31,
In your code, I suggest adding a call() function to the VAE class:
def call(self, x):
    q_z = self.encode(x)
    z = q_z.sample()
    x_recon = self.decode(z)
    return x_recon
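As a further sketch (my own extension, not part of the original answer): since model.compile(optimizer=optimizer) is called without a loss, call() also needs to register the VAE objective, for example with add_loss, so that fit has something to minimize; the generator then only needs to yield input batches. Reusing the class's own compute_loss keeps it short, at the cost of a second forward pass:

def call(self, x):
    q_z = self.encode(x)
    z = q_z.sample()
    x_recon = self.decode(z)
    # register the VAE objective so compile(optimizer)/fit can train the model;
    # compute_loss re-runs the encoder and decoder, which is wasteful but simple
    recon_loss, latent_loss = self.compute_loss(x)
    self.add_loss(recon_loss + latent_loss)
    return x_recon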
I would also suggest using a more standard approach to your task, especially as a beginner: