Reputation: 4521
I'm going through a GAN tutorial and I've noticed the use of 'reuse' flags, and I don't quite get what they are doing or preventing. If you take a look at the code below, you will see that reuse is used in each variable scope initialization.
(I tried looking at the docs but I'm still not clear: https://www.tensorflow.org/versions/r0.12/how_tos/variable_scope/)
def discriminator(images, reuse=False):
    """
    Create the discriminator network
    """
    alpha = 0.2
    with tf.variable_scope('discriminator', reuse=reuse):
        # using 4 layer network as in DCGAN Paper
        # Conv 1
        conv1 = tf.layers.conv2d(images, 64, 5, 2, 'SAME')
        lrelu1 = tf.maximum(alpha * conv1, conv1)
        # Conv 2
        conv2 = tf.layers.conv2d(lrelu1, 128, 5, 2, 'SAME')
        batch_norm2 = tf.layers.batch_normalization(conv2, training=True)
        lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
        # Conv 3
        conv3 = tf.layers.conv2d(lrelu2, 256, 5, 1, 'SAME')
        batch_norm3 = tf.layers.batch_normalization(conv3, training=True)
        lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
        # Flatten
        flat = tf.reshape(lrelu3, (-1, 4*4*256))
        # Logits
        logits = tf.layers.dense(flat, 1)
        # Output
        out = tf.sigmoid(logits)
        return out, logits
def generator(z, out_channel_dim, is_train=True):
    """
    Create the generator network
    """
    alpha = 0.2
    with tf.variable_scope('generator', reuse=not is_train):
        # First fully connected layer
        x_1 = tf.layers.dense(z, 2*2*512)
        # Reshape it to start the convolutional stack
        deconv_2 = tf.reshape(x_1, (-1, 2, 2, 512))
        batch_norm2 = tf.layers.batch_normalization(deconv_2, training=is_train)
        lrelu2 = tf.maximum(alpha * batch_norm2, batch_norm2)
        # Deconv 1
        deconv3 = tf.layers.conv2d_transpose(lrelu2, 256, 5, 2, padding='VALID')
        batch_norm3 = tf.layers.batch_normalization(deconv3, training=is_train)
        lrelu3 = tf.maximum(alpha * batch_norm3, batch_norm3)
        # Deconv 2
        deconv4 = tf.layers.conv2d_transpose(lrelu3, 128, 5, 2, padding='SAME')
        batch_norm4 = tf.layers.batch_normalization(deconv4, training=is_train)
        lrelu4 = tf.maximum(alpha * batch_norm4, batch_norm4)
        # Deconv 3
        deconv5 = tf.layers.conv2d_transpose(lrelu4, 64, 5, 2, padding='SAME')
        batch_norm5 = tf.layers.batch_normalization(deconv5, training=is_train)
        lrelu5 = tf.maximum(alpha * batch_norm5, batch_norm5)
        # Output layer
        logits = tf.layers.conv2d_transpose(lrelu5, out_channel_dim, 5, 2, padding='SAME')
        out = tf.tanh(logits)
        return out
Thank you.
Upvotes: 0
Views: 720
Reputation: 1669
For the generator, we're going to train it, but we also want to sample from it, both while it's training and after training. The discriminator will need to share variables between the fake and real input images. So we can use the reuse keyword argument of tf.variable_scope to tell TensorFlow to reuse the existing variables instead of creating new ones when we build that part of the graph again.
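For instance, a minimal sketch of the two generator calls (input_z and out_channel_dim are assumed to be the noise placeholder and the number of output channels from the rest of the tutorial, not names taken from your code):

# Training pass: is_train=True, so reuse is False and the
# 'generator' variables are created for the first time.
g_model = generator(input_z, out_channel_dim, is_train=True)

# Sampling pass, during or after training: is_train=False, so
# reuse is True and the same 'generator' variables are looked up
# instead of being created a second time.
samples = generator(input_z, out_channel_dim, is_train=False)

Note that this only works if the training call is built first; a reuse=True scope can only look up variables that already exist.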
Then the discriminators. We build two of them, one for real data and one for fake data. Since we want the weights to be the same for both real and fake inputs, we need to reuse the variables. The fake data comes from the generator as g_model, so the real-data discriminator is discriminator(input_real) while the fake-data discriminator is discriminator(g_model, reuse=True), as in the sketch below.
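In code, that typically looks like this (input_real is assumed to be the placeholder holding real images; the d_* names are just illustrative):

# First call: reuse=False (the default), so the 'discriminator'
# variables (conv kernels, dense weights, ...) are created here.
d_model_real, d_logits_real = discriminator(input_real)

# Second call, on the generator's output: reuse=True, so TensorFlow
# looks up the existing 'discriminator' variables instead of creating
# a second, independent set. Both calls therefore score images with
# exactly the same weights.
d_model_fake, d_logits_fake = discriminator(g_model, reuse=True)

Without reuse=True, the second call would fail with a ValueError, because TensorFlow refuses to silently create a second variable under a name that already exists in the 'discriminator' scope.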
Upvotes: 2