fegemo

Reputation: 2594

In a GAN with a custom training loop, how can I train the discriminator more times than the generator (as in WGAN) in TensorFlow?

I'm working on code based on the [Pix2Pix TensorFlow tutorial](https://www.tensorflow.org/tutorials/generative/pix2pix) and I'm trying to follow the Wasserstein GAN (WGAN) requirements: (a) weight clipping, (b) a linear activation on the discriminator's output, (c) the Wasserstein loss, and (d) training the discriminator multiple times for each generator step.
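For reference, this is roughly how I understand (a) and (c) in TensorFlow; the function names are my own and the clip value of 0.01 is the one suggested in the WGAN paper:

import tensorflow as tf

def wasserstein_critic_loss(disc_real_output, disc_generated_output):
  # Critic maximizes E[D(real)] - E[D(fake)], i.e. minimizes the negation.
  return tf.reduce_mean(disc_generated_output) - tf.reduce_mean(disc_real_output)

def wasserstein_generator_loss(disc_generated_output):
  # Generator maximizes the critic's score on generated samples.
  return -tf.reduce_mean(disc_generated_output)

def clip_discriminator_weights(discriminator, clip_value=0.01):
  # (a) Enforce the Lipschitz constraint by clipping every weight.
  for var in discriminator.trainable_variables:
    var.assign(tf.clip_by_value(var, -clip_value, clip_value))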

I have a custom training loop that uses two gradient tapes (as in the tutorial). The code for the training step looks like this:

@tf.function
def train_step(input_image, target, step):
  # Record the forward passes of both networks on separate tapes.
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  # Compute and apply gradients for both networks, once each per step.
  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

My question: how can I adapt this code to train the discriminator multiple times for each time I train the generator?

Upvotes: 1

Views: 831

Answers (1)

Sascha Kirch

Reputation: 524

You could use separate gradient tapes for the generator and the discriminator and loop several times over the discriminator update:

@tf.function
def train_step(input_image, target, step):
  # One generator update per call, on its own tape.
  with tf.GradientTape() as gen_tape:
    gen_output = generator(input_image, training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)
    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
  generator_gradients = gen_tape.gradient(gen_total_loss, generator.trainable_variables)
  generator_optimizer.apply_gradients(zip(generator_gradients, generator.trainable_variables))

  # Several discriminator updates per call, each on a fresh tape.
  disc_train_iterations = 5
  for i in range(disc_train_iterations):
    with tf.GradientTape() as disc_tape:
      # Regenerate fakes; the generator is not recorded on this tape.
      gen_output = generator(input_image, training=True)
      disc_real_output = discriminator([input_image, target], training=True)
      disc_generated_output = discriminator([input_image, gen_output], training=True)
      disc_loss = discriminator_loss(disc_real_output, disc_generated_output)
    discriminator_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients, discriminator.trainable_variables))
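Two remarks: inside a @tf.function, the Python for loop is unrolled into the graph at trace time, which is fine for a small constant disc_train_iterations. And since you mentioned WGAN's weight clipping, that would go inside the loop, right after apply_gradients, along these lines (clip value from the WGAN paper, adjust as needed):

    for var in discriminator.trainable_variables:
      var.assign(tf.clip_by_value(var, -0.01, 0.01))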

Upvotes: 2
