如何在使用批处理数据集训练网络时跟踪总损失?

我正在尝试通过对其优化器应用梯度来训练一个判别器网络。然而,当我使用tf.GradientTape来查找损失相对于训练变量的梯度时,返回的是None。以下是训练循环:

def train_step():  #生成噪声种子  noise = tf.random.normal([BATCH_SIZE, noise_dim])  with tf.GradientTape() as disc_tape:    pattern = generator(noise)    pattern = tf.reshape(tensor=pattern, shape=(28,28,1))    dataset = get_data_set(pattern)    disc_loss = tf.Variable(shape=(1,2), initial_value=[[0,0]], dtype=tf.float32)    disc_tape.watch(disc_loss)    for batch in dataset:        disc_loss.assign_add(discriminator(batch, training=True))  disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

代码描述

生成器网络从噪声中生成一个“模式”。然后我通过对张量应用各种卷积来从该模式生成数据集。返回的数据集是批处理的,因此我遍历数据集,并通过将此批次的损失加到总损失上来跟踪我的判别器的损失。

我所知道的

当两个变量之间没有图形连接时,tf.GradientTape会返回None。但损失和可训练变量之间不是有图形连接吗?我认为我的错误与我如何在disc_loss tf.Variable中跟踪损失有关

我的问题

在遍历批处理数据集时,如何跟踪损失,以便稍后用于计算梯度?


回答:

这里的基本答案是,tf.Variable的assign_add函数是不可微分的,因此无法计算变量disc_loss与判别器可训练变量之间的梯度。

在这种非常具体的情况下,答案是

disc_loss = disc_loss + discriminator(batch, training=True)

在将来遇到类似问题时,请确保在被梯度带监视时使用的所有操作都是可微分的。

这个链接有一份可微分和不可微分的TensorFlow操作列表。我发现它非常有用。

Related Posts

使用LSTM在Python中预测未来值

这段代码可以预测指定股票的当前日期之前的值,但不能预测…

如何在gensim的word2vec模型中查找双词组的相似性

我有一个word2vec模型,假设我使用的是googl…

dask_xgboost.predict 可以工作但无法显示 – 数据必须是一维的

我试图使用 XGBoost 创建模型。 看起来我成功地…

ML Tuning – Cross Validation in Spark

我在https://spark.apache.org/…

如何在React JS中使用fetch从REST API获取预测

我正在开发一个应用程序,其中Flask REST AP…

如何分析ML.NET中多类分类预测得分数组?

我在ML.NET中创建了一个多类分类项目。该项目可以对…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注