何时调用模型的call()和train_step()?

我正在学习如何自定义训练循环的教程

https://colab.research.google.com/github/tensorflow/docs/blob/snapshot-keras/site/en/guide/keras/customizing_what_happens_in_fit.ipynb#scrollTo=46832f2077ac

最后一个例子展示了一个使用自定义训练实现的GAN,其中只定义了__init__train_stepcompile方法

class GAN(keras.Model):    def __init__(self, discriminator, generator, latent_dim):        super(GAN, self).__init__()        self.discriminator = discriminator        self.generator = generator        self.latent_dim = latent_dim    def compile(self, d_optimizer, g_optimizer, loss_fn):        super(GAN, self).compile()        self.d_optimizer = d_optimizer        self.g_optimizer = g_optimizer        self.loss_fn = loss_fn    def train_step(self, real_images):        if isinstance(real_images, tuple):            real_images = real_images[0]        ...

如果我的模型也有一个自定义的call()函数会怎样?train_step()会覆盖call()吗?call()train_step()不是都被fit()调用吗?两者之间有什么区别?

下面是我写的另一段代码,我在想fit()中调用的是call()还是train_step()

class MyModel(tf.keras.Model):  def __init__(self, vocab_size, embedding_dim, rnn_units):    super().__init__(self)    self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)    self.gru = tf.keras.layers.GRU(rnn_units,                                   return_sequences=True,                                   return_state=True,                                   reset_after=True                                   )    self.dense = tf.keras.layers.Dense(vocab_size)  def call(self, inputs, states=None, return_state=False, training=False):    x = inputs    x = self.embedding(x, training=training)    if states is None:      states = self.gru.get_initial_state(x)    x, states = self.gru(x, initial_state=states, training=training)    x = self.dense(x, training=training)    if return_state:      return x, states    else:      return x  @tf.function  def train_step(self, inputs):    # 解包数据    inputs, labels = inputs      with tf.GradientTape() as tape:      predictions = self(inputs, training=True) # 前向传递      # 计算损失值      # (损失函数在`compile()`中配置)      loss=self.compiled_loss(labels, predictions, regularization_losses=self.losses)    # 计算梯度    grads=tape.gradient(loss, model.trainable_variables)    # 更新权重    self.optimizer.apply_gradients(zip(grads, model.trainable_variables))    # 更新指标(包括跟踪损失的指标)    self.compiled_metrics.update_state(labels, predictions)    # 返回一个将指标名称映射到当前值的字典    return {m.name: m.result() for m in self.metrics}

回答:

这些是不同的概念,使用方式如下:

  • train_stepfit调用。基本上,fit会循环遍历数据集,并将每个批次提供给train_step(当然还会处理指标、记录等)。
  • call在你调用模型时使用。确切地说,编写model(inputs)或在你的情况下self(inputs)将使用__call__函数,但Model类定义了该函数,使其反过来使用call

这些是技术方面的。直观上:

  • call应该定义模型的前向传递。即输入如何转换为输出。
  • train_step定义训练步骤的逻辑,通常使用梯度下降。它经常会使用call,因为训练步骤往往包括模型的前向传递以计算梯度。

关于你链接的GAN教程,我认为那实际上可以被认为是不完整的。它在没有定义call的情况下工作,因为自定义的train_step明确调用了生成器/判别器字段(因为这些是预定义的模型,可以像往常一样调用)。如果你尝试像gan(inputs)那样调用GAN模型,我假设你会得到一个错误消息(我没有测试这一点)。所以你总是必须调用gan.generator(inputs)来生成,例如。

最后(这部分可能有点 confusing),请注意你可以子类化Model来定义自定义训练步骤,但随后通过功能API初始化它(如model = Model(inputs, outputs)),在这种情况下,你可以在训练步骤中使用call,而无需自己定义它,因为功能API会处理这一点。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注