我正在学习如何自定义训练循环的教程
最后一个例子展示了一个使用自定义训练实现的GAN,其中只定义了__init__
、train_step
和compile
方法
class GAN(keras.Model): def __init__(self, discriminator, generator, latent_dim): super(GAN, self).__init__() self.discriminator = discriminator self.generator = generator self.latent_dim = latent_dim def compile(self, d_optimizer, g_optimizer, loss_fn): super(GAN, self).compile() self.d_optimizer = d_optimizer self.g_optimizer = g_optimizer self.loss_fn = loss_fn def train_step(self, real_images): if isinstance(real_images, tuple): real_images = real_images[0] ...
如果我的模型也有一个自定义的call()
函数会怎样?train_step()
会覆盖call()
吗?call()
和train_step()
不是都被fit()
调用吗?两者之间有什么区别?
下面是我写的另一段代码,我在想fit()
中调用的是call()
还是train_step()
:
class MyModel(tf.keras.Model): def __init__(self, vocab_size, embedding_dim, rnn_units): super().__init__(self) self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim) self.gru = tf.keras.layers.GRU(rnn_units, return_sequences=True, return_state=True, reset_after=True ) self.dense = tf.keras.layers.Dense(vocab_size) def call(self, inputs, states=None, return_state=False, training=False): x = inputs x = self.embedding(x, training=training) if states is None: states = self.gru.get_initial_state(x) x, states = self.gru(x, initial_state=states, training=training) x = self.dense(x, training=training) if return_state: return x, states else: return x @tf.function def train_step(self, inputs): # 解包数据 inputs, labels = inputs with tf.GradientTape() as tape: predictions = self(inputs, training=True) # 前向传递 # 计算损失值 # (损失函数在`compile()`中配置) loss=self.compiled_loss(labels, predictions, regularization_losses=self.losses) # 计算梯度 grads=tape.gradient(loss, model.trainable_variables) # 更新权重 self.optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 更新指标(包括跟踪损失的指标) self.compiled_metrics.update_state(labels, predictions) # 返回一个将指标名称映射到当前值的字典 return {m.name: m.result() for m in self.metrics}
回答:
这些是不同的概念,使用方式如下:
train_step
由fit
调用。基本上,fit
会循环遍历数据集,并将每个批次提供给train_step
(当然还会处理指标、记录等)。call
在你调用模型时使用。确切地说,编写model(inputs)
或在你的情况下self(inputs)
将使用__call__
函数,但Model
类定义了该函数,使其反过来使用call
。
这些是技术方面的。直观上:
call
应该定义模型的前向传递。即输入如何转换为输出。train_step
定义训练步骤的逻辑,通常使用梯度下降。它经常会使用call
,因为训练步骤往往包括模型的前向传递以计算梯度。
关于你链接的GAN教程,我认为那实际上可以被认为是不完整的。它在没有定义call
的情况下工作,因为自定义的train_step
明确调用了生成器/判别器字段(因为这些是预定义的模型,可以像往常一样调用)。如果你尝试像gan(inputs)
那样调用GAN模型,我假设你会得到一个错误消息(我没有测试这一点)。所以你总是必须调用gan.generator(inputs)
来生成,例如。
最后(这部分可能有点 confusing),请注意你可以子类化Model
来定义自定义训练步骤,但随后通过功能API初始化它(如model = Model(inputs, outputs)
),在这种情况下,你可以在训练步骤中使用call
,而无需自己定义它,因为功能API会处理这一点。