Reputation: 2594
I'm working on code based on the [Pix2Pix tensorflow tutorial][tutorial] and I'm trying to follow the Wasserstein GAN (WGAN) requirements: (a) weight clipping, (b) linear activation for the discriminator, (c) Wasserstein loss, and (d) training the discriminator multiple times for each generator step.
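For context, here is a minimal sketch of what requirements (a)–(c) can look like in TensorFlow. The names `wasserstein_disc_loss`, `wasserstein_gen_loss`, and `clip_discriminator_weights` are illustrative placeholders rather than functions from the tutorial, and the clip value of 0.01 is the one used in the original WGAN paper:

```python
import tensorflow as tf

# (b) The critic's last layer uses a linear activation, so its outputs are
# unbounded scores rather than probabilities.

# (c) Wasserstein losses: signed means of the critic scores, no cross-entropy.
def wasserstein_disc_loss(disc_real_output, disc_generated_output):
    # The critic minimizes E[score(fake)] - E[score(real)].
    return tf.reduce_mean(disc_generated_output) - tf.reduce_mean(disc_real_output)

def wasserstein_gen_loss(disc_generated_output):
    # The generator maximizes the critic's score on generated samples.
    return -tf.reduce_mean(disc_generated_output)

# (a) Weight clipping: after every critic update, clip each critic weight
# into [-c, c] (c = 0.01 in the WGAN paper).
def clip_discriminator_weights(discriminator, clip_value=0.01):
    for var in discriminator.trainable_variables:
        var.assign(tf.clip_by_value(var, -clip_value, clip_value))
```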
I have a custom training loop that uses two gradient tapes (as in the tutorial). The code for the training step looks like this:
```python
@tf.function
def train_step(input_image, target, step):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator(input_image, training=True)

        disc_real_output = discriminator([input_image, target], training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)

        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))
```
My question: how can I adapt this code so that the discriminator is trained multiple times for each time the generator is trained?
Upvotes: 1
Views: 831
Reputation: 524
You could use separate gradient tapes for the generator and the discriminator training, and loop several times over the discriminator update.
```python
@tf.function
def train_step(input_image, target, step):
    # Generator update
    with tf.GradientTape() as gen_tape:
        gen_output = generator(input_image, training=True)
        disc_generated_output = discriminator([input_image, gen_output], training=True)
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)

    generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

    # Discriminator updates: several per generator update
    disc_train_iterations = 5
    for i in range(disc_train_iterations):
        with tf.GradientTape() as disc_tape:
            gen_output = generator(input_image, training=True)
            disc_real_output = discriminator([input_image, target], training=True)
            disc_generated_output = discriminator([input_image, gen_output], training=True)
            disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

        discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
        discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
```
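To also satisfy the weight-clipping requirement (a) from the question, the clipping step would typically go inside that loop, right after `discriminator_optimizer.apply_gradients(...)`. A minimal sketch, assuming the clip value of 0.01 from the WGAN paper:

```python
# Inside the for-loop above, right after discriminator_optimizer.apply_gradients(...):
for var in discriminator.trainable_variables:
    var.assign(tf.clip_by_value(var, -0.01, 0.01))  # WGAN requirement (a)
```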
Upvotes: 2