修改Keras类以包含call函数的问题

我想训练一个拥有巨大数据集的变分自编码器(VAE),因此决定使用一个为Fashion MNIST设计的VAE 代码,并采用了在GitHub上找到的用于批量加载文件名的流行修改。我的研究协作笔记本在这里这里,以及数据集的一个样本部分。

但是,VAE类的编写方式没有包含根据Keras 文档应有的call函数。我收到了错误信息NotImplementedError: 当子类化Model类时,您应该实现一个call方法。

class VAE(tf.keras.Model):"""tensorflow的一个基本vae类Extends:    tf.keras.Model"""def __init__(self, **kwargs):    super(VAE, self).__init__()    self.__dict__.update(kwargs)    self.enc = tf.keras.Sequential(self.enc)    self.dec = tf.keras.Sequential(self.dec)def encode(self, x):    mu, sigma = tf.split(self.enc(x), num_or_size_splits=2, axis=1)    return ds.MultivariateNormalDiag(loc=mu, scale_diag=sigma)def reparameterize(self, mean, logvar):    eps = tf.random.normal(shape=mean.shape)    return eps * tf.exp(logvar * 0.5) + meandef reconstruct(self, x):    mu, _ = tf.split(self.enc(x), num_or_size_splits=2, axis=1)    return self.decode(mu)def decode(self, z):    return self.dec(z)def compute_loss(self, x):    q_z = self.encode(x)    z = q_z.sample()    x_recon = self.decode(z)    p_z = ds.MultivariateNormalDiag(      loc=[0.] * z.shape[-1], scale_diag=[1.] * z.shape[-1]      )    kl_div = ds.kl_divergence(q_z, p_z)    latent_loss = tf.reduce_mean(tf.maximum(kl_div, 0))    recon_loss = tf.reduce_mean(tf.reduce_sum(tf.math.square(x - x_recon), axis=0))    return recon_loss, latent_lossdef compute_gradients(self, x):    with tf.GradientTape() as tape:        loss = self.compute_loss(x)    return tape.gradient(loss, self.trainable_variables)@tf.functiondef train(self, train_x):    gradients = self.compute_gradients(train_x)    self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))

编码器和解码器是分开定义并编译的,如下所示:

N_Z = 8filt_base = 32DIMS = (128,128,3)encoder = [tf.keras.layers.InputLayer(input_shape=DIMS),tf.keras.layers.Conv2D(    filters=filt_base, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base*2, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base*2, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base*3, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base*3, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base*4, kernel_size=3, strides=(1, 1), activation="relu", padding="same"),tf.keras.layers.Conv2D(    filters=filt_base*4, kernel_size=3, strides=(2, 2), activation="relu", padding="same"),tf.keras.layers.Flatten(),tf.keras.layers.Dense(units=N_Z*2),]decoder = [tf.keras.layers.Dense(units=8 * 8 * 128, activation="relu"),tf.keras.layers.Reshape(target_shape=(8, 8, 128)),tf.keras.layers.Conv2DTranspose(    filters=filt_base*4, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=filt_base*4, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=filt_base*3, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=filt_base*3, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=filt_base*2, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=filt_base*2, kernel_size=3, strides=(1, 1), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=filt_base, kernel_size=3, strides=(2, 2), padding="SAME", activation="relu"),tf.keras.layers.Conv2DTranspose(    filters=1, kernel_size=3, strides=(1, 1), padding="SAME", activation="sigmoid"),]optimizer = tf.keras.optimizers.Adam(1e-3)model = VAE(  enc = encoder,  dec = decoder,  optimizer = optimizer,)model.compile(optimizer=optimizer)

并尝试使用fit_generator函数训练模型

num_epochs = 50model.fit_generator(generator=my_training_batch_generator,                                      steps_per_epoch=(num_training_samples // batch_size),                                      epochs=num_epochs,                                      verbose=1,                                      validation_data=my_validation_batch_generator,                                      validation_steps=(num_validation_samples // batch_size),                                      use_multiprocessing=True,                                      workers=16,                                      max_queue_size=32)

我是机器学习的新手,任何帮助解决此问题都会被感激。我认为问题出在VAE类中的def train行上。

一个可选的请求是,如果可以进行训练,使我能够在每个epoch后看到重建结果,那将非常感谢。我已经在研究协作笔记本中为此目的准备了一个plot_reconstruction函数,需要调用它。


回答:

APaul31,

在你的代码中,我建议为VAE类添加call()函数:

def call(self, x):    q_z = self.encode(x)    z = q_z.sample()    x_recon = self.decode(z)

我还建议使用更标准的方法来完成你的任务,特别是作为初学者:

  1. 使用tf.keras.preprocessing.image_dataset_from_directory()来加载图像。教程在这里

  2. 使用自定义的Model.train_step()来计算VAE损失,而不是在你的VAE类中使用多个函数。示例在这里

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注