Rahul Barman

Reputation: 13

Can't understand the loss functions for the GAN model used in the tensorflow documentation

I can't understand the loss function of the GAN model in the TensorFlow documentation. Why use tf.ones_like() for the real outputs and tf.zeros_like() for the fake outputs?

# In the tutorial, cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
  real_loss = cross_entropy(tf.ones_like(real_output), real_output)
  fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
  total_loss = real_loss + fake_loss
  return total_loss

Upvotes: 1

Views: 509

Answers (1)

Tirth Patel

Reputation: 1905

We have the following loss functions, which we need to minimize in a minimax fashion:

  1. generator_loss = -log(fake_output)
  2. discriminator_loss = -log(real_output) - log(1 - fake_output)

where real_output is the discriminator's output on real images, D(x), and fake_output is its output on generated images, D(G(z)).

Now, with this in mind, let's see what the code snippet in TensorFlow's documentation represents:

  • real_loss = cross_entropy(tf.ones_like(real_output), real_output) evaluates to
    • real_loss = -1 * log(real_output) - (1 - 1) * log(1 - real_output) = -log(real_output)
  • fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output) evaluates to
    • fake_loss = -0 * log(fake_output) - (1 - 0) * log(1 - fake_output) = -log(1 - fake_output)
  • total_loss = real_loss + fake_loss evaluates to
    • total_loss = -log(real_output) - log(1 - fake_output)
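To see the reduction numerically, here is a tiny single-scalar check, using a hand-rolled binary cross-entropy in place of tf.keras.losses.BinaryCrossentropy (and assuming probabilities, not logits):

```python
import math

def binary_cross_entropy(y, p):
    # BCE for a single probability p and label y:
    # -y*log(p) - (1-y)*log(1-p)
    return -y * math.log(p) - (1 - y) * math.log(1 - p)

real_output, fake_output = 0.8, 0.3

real_loss = binary_cross_entropy(1.0, real_output)  # label 1 -> -log(p)
fake_loss = binary_cross_entropy(0.0, fake_output)  # label 0 -> -log(1-p)

# The label of all ones zeroes out the second term;
# the label of all zeros zeroes out the first.
assert abs(real_loss - (-math.log(real_output))) < 1e-12
assert abs(fake_loss - (-math.log(1 - fake_output))) < 1e-12
```

So the ones/zeros labels are simply a way of selecting which term of the cross-entropy survives for each batch.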

Clearly, this recovers the discriminator's loss in the minimax game, which the discriminator seeks to minimize.

Upvotes: 1
