Tensorflow – tf.variable_scope , GAN中reuse参数的使用

我在做一个项目,需要构建一个生成对抗网络(GAN),我非常想理解TensorFlow中variable_scope的变量共享是如何工作的。

在构建GAN时,我有一个生成器网络和两个判别器网络:一个判别器接受真实图像,另一个判别器接受由生成器创建的假图像。重要的是,接受真实图像的判别器和接受假图像的判别器需要共享相同的权重。为了实现这一点,我需要共享权重。

我有一个判别器和生成器的定义,比如说:

def discriminator(images, reuse=False):    with variable_scope("discriminator", reuse=reuse):        #.... 层定义,这里不重要        #....        logits = tf.layers.dense(X, 1)        logits = tf.identity(logits, name="logits")        out = tf.sigmoid(logits, name="out")        # 14x14x64    return logits, outdef generator(input_z, reuse=False):    with variable_scope("generator", reuse=reuse):        #.. 不太重要         out = tf.tanh(logits)    return out

现在生成器和判别器函数被调用:

g_model = generator(input_z)d_model_real, d_logits_real = discriminator(input_real)#在这里,reuse=True应该会在d_model_real, d_logits_real#和d_model_fake, d_logits_fake之间产生权重共享.. 为什么?d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

为什么在第二次调用时,reuse=True会导致权重共享?据我所知,你需要在第一次调用时决定变量的重用,以便在程序的其他地方使用它们。

如果有人能向我解释这一点,我会非常高兴,我找不到关于这个主题的好资料,这对我来说似乎非常 confusing 和复杂。谢谢!


回答:

在底层,变量是使用tf.get_variable()创建的。

这个函数会为变量名称加上前缀,并在创建新变量之前检查它是否存在。

例如,如果你在"fc"作用域内调用tf.get_variable("w", [10,10]),变量名称将是"fc/w:0"

现在,当你第二次这样做时,如果reuse=True,作用域仍然是"fc",get_variable将重用变量"fc/w:0"

然而,如果reuse=False,你会得到一个错误,因为变量"fc/w:0"已经存在,提示你使用不同的名称或使用reuse=True

示例:

In [1]: import tensorflow as tfIn [2]: with tf.variable_scope("fc"):   ...:      v = tf.get_variable("w", [10,10])   ...:In [3]: vOut[3]: <tf.Variable 'fc/w:0' shape=(10, 10) dtype=float32_ref>In [4]: with tf.variable_scope("fc"):   ...:      v = tf.get_variable("w", [10,10])   ...:ValueError: Variable fc/w already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?In [5]: with tf.variable_scope("fc", reuse=True):   ...:      v = tf.get_variable("w", [10,10])   ...:In [6]: vOut[6]: <tf.Variable 'fc/w:0' shape=(10, 10) dtype=float32_ref>

请注意,除了共享权重,你也可以只实例化一个判别器。然后你可以使用placeholder_with_default决定是向它输入真实数据还是生成数据。

Related Posts

Keras Dense层输入未被展平

这是我的测试代码: from keras import…

无法将分类变量输入随机森林

我有10个分类变量和3个数值变量。我在分割后直接将它们…

如何在Keras中对每个输出应用Sigmoid函数?

这是我代码的一部分。 model = Sequenti…

如何选择类概率的最佳阈值?

我的神经网络输出是一个用于多标签分类的预测类概率表: …

在Keras中使用深度学习得到不同的结果

我按照一个教程使用Keras中的深度神经网络进行文本分…

‘MatMul’操作的输入’b’类型为float32,与参数’a’的类型float64不匹配

我写了一个简单的TensorFlow代码,但不断遇到T…

发表回复

您的邮箱地址不会被公开。 必填项已用 * 标注