这是我在tensorflow中训练GAN的代码。我正在训练一个判别器来区分假视频和原始视频。我已经省略了不相关的代码部分,以避免堆栈溢出,主要是代码错误
X = tf.placeholder(tf.float32, shape=[None, 28, 28])D_W1 = tf.Variable(xavier_init([1024, 128]))D_b1 = tf.Variable(tf.zeros(shape=[128]))D_W2 = tf.Variable(xavier_init([128, 1]))D_b2 = tf.Variable(tf.zeros(shape=[1]))theta_D = [D_W1, D_W2, D_b1, D_b2]rnn_size = 1024rnn_layer = 2Z = tf.placeholder(tf.float32, shape=[None, 100])G_W1 = tf.Variable(xavier_init([100, 128]))G_b1 = tf.Variable(tf.zeros(shape=[128]))G_W2 = tf.Variable(xavier_init([128, 784]))G_b2 = tf.Variable(tf.zeros(shape=[784]))theta_G = [G_W1, G_W2, G_b1, G_b2]def sample_Z(m, n): return np.random.uniform(-1., 1., size=[m, n])def generator(z): G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1) G_log_prob = tf.matmul(G_h1, G_W2) + G_b2 G_prob = tf.nn.sigmoid(G_log_prob) G_prob = tf.reshape(G_prob, [-1,28, 28]) return G_probdef discriminator(x): x = [tf.squeeze(t, [1]) for t in tf.split(x, 28, 1)] # with tf.variable_scope('cell_def'): stacked_rnn1 = [] for iiLyr1 in range(rnn_layer): stacked_rnn1.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size, state_is_tuple=True)) lstm_multi_fw_cell = tf.contrib.rnn.MultiRNNCell(cells=stacked_rnn1) # with tf.variable_scope('rnn_def'): dec_outputs, dec_state = tf.contrib.rnn.static_rnn( lstm_multi_fw_cell, x, dtype=tf.float32) D_h1 = tf.nn.relu(tf.matmul(dec_outputs[-1], D_W1) + D_b1) D_logit = tf.matmul(D_h1, D_W2) + D_b2 D_prob = tf.nn.sigmoid(D_logit) return D_prob, D_logitG_sample = generator(Z)print(G_sample.get_shape())print(X.get_shape())D_real, D_logit_real = discriminator(X)D_fake, D_logit_fake = discriminator(G_sample)D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))G_loss = -tf.reduce_mean(tf.log(D_fake))summary_d = tf.summary.histogram('D_loss histogram', D_loss)summary_g = tf.summary.histogram('D_loss histogram', G_loss)summary_s = tf.summary.scalar('D_loss scalar', D_loss)summary_s1 = tf.summary.scalar('scalar scalar', G_loss)# Add image summarysummary_op = tf.summary.image("plot", image)D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)mb_size = 128Z_dim = 100mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)# merged_summary_op = tf.summary.merge_all()sess = tf.Session()saver = tf.train.Saver()writer1 = tf.summary.FileWriter('log/log-sample1', sess.graph)writer2 = tf.summary.FileWriter('log/log-sample2', sess.graph)sess.run(tf.global_variables_initializer())if not os.path.exists('out/'): os.makedirs('out/')i = 0with tf.variable_scope("myrnn") as scope: for it in range(5000): X_mb, _ = mnist.train.next_batch(mb_size) X_mb = tf.reshape(X_mb, [mb_size, -1, 28]) _, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)}) _, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)}) summary_str, eded = sess.run([summary_d, summary_s], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)}) writer1.add_summary(summary_str, it) writer1.add_summary(eded, it) summary_str1, eded1 = sess.run([summary_g, summary_s1], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)}) writer2.add_summary(summary_str1, it) writer2.add_summary(eded1, it) if it % 1000 == 0: print('Iter: {}'.format(it)) print('D loss: {:.4}'. format(D_loss_curr)) print('G_loss: {:.4}'.format(G_loss_curr)) print() save_path = saver.save(sess, "tmp/model.ckpt") writer1.close() writer2.close()`
运行此代码时出现以下错误,请帮助解决。
Traceback (most recent call last): File "/Users/tulsijain/Desktop/Deep Learning Practise/GAN/vanila.py", line 104, in <module> D_fake, D_logit_fake = discriminator(G_sample) File "/Users/tulsijain/Desktop/Deep Learning Practise/GAN/vanila.py", line 64, in discriminator lstm_multi_fw_cell, x, dtype=tf.float32) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 1212, in static_rnn (output, state) = call_cell() File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn.py", line 1199, in <lambda> call_cell = lambda: cell(input_, state) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__ return super(RNNCell, self).__call__(inputs, state) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 441, in __call__ outputs = self.call(inputs, *args, **kwargs) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 916, in call cur_inp, new_state = cell(cur_inp, cur_state) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 180, in __call__ return super(RNNCell, self).__call__(inputs, state) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/layers/base.py", line 441, in __call__ outputs = self.call(inputs, *args, **kwargs) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 383, in call concat = _linear([inputs, h], 4 * self._num_units, True) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 1017, in _linear initializer=kernel_initializer) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 1065, in get_variable use_resource=use_resource, custom_getter=custom_getter) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 962, in get_variable use_resource=use_resource, custom_getter=custom_getter) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 360, in get_variable validate_shape=validate_shape, use_resource=use_resource) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 1405, in wrapped_custom_getter *args, **kwargs) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in _rnn_get_variable variable = getter(*args, **kwargs) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/rnn_cell_impl.py", line 183, in _rnn_get_variable variable = getter(*args, **kwargs) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 352, in _true_getter use_resource=use_resource) File "/Library/Frameworks/Python.framework/Versions/3.5/lib/python3.5/site-packages/tensorflow/python/ops/variable_scope.py", line 664, in _get_single_variable name, "".join(traceback.format_list(tb))))ValueError: Variable rnn/multi_rnn_cell/cell_0/basic_lstm_cell/kernel already exists, disallowed. Did you mean to set reuse=True in VarScope? Originally defined at: File "/Users/tulsijain/Desktop/Deep Learning Practise/GAN/vanila.py", line 64, in discriminator lstm_multi_fw_cell, x, dtype=tf.float32) File "/Users/tulsijain/Desktop/Deep Learning Practise/GAN/vanila.py", line 103, in <module> D_real, D_logit_real = discriminator(X)
这是GAN。我使用MNIST数据来训练生成器和判别器。
回答:
在BasicLSTMCell
中添加一个reuse
参数。由于你两次调用了discriminator
函数,并且每次都设置reuse=None
,因此在尝试创建同名变量时会抛出错误。在这种情况下,你需要在第二次调用时重用图中的变量,因为你不需要创建一组新的变量。
def discriminator(x, reuse): x = [tf.squeeze(t, [1]) for t in tf.split(x, 28, 1)] # with tf.variable_scope('cell_def'): stacked_rnn1 = [] for iiLyr1 in range(rnn_layer): stacked_rnn1.append(tf.nn.rnn_cell.BasicLSTMCell(num_units=rnn_size, state_is_tuple=True, reuse=reuse)) lstm_multi_fw_cell = tf.contrib.rnn.MultiRNNCell(cells=stacked_rnn1) # with tf.variable_scope('rnn_def'): dec_outputs, dec_state = tf.contrib.rnn.static_rnn( lstm_multi_fw_cell, x, dtype=tf.float32) D_h1 = tf.nn.relu(tf.matmul(dec_outputs[-1], D_W1) + D_b1) D_logit = tf.matmul(D_h1, D_W2) + D_b2 D_prob = tf.nn.sigmoid(D_logit) return D_prob, D_logit....D_real, D_logit_real = discriminator(X, None)D_fake, D_logit_fake = discriminator(G_sample, True)....