Reputation: 1504
I am training a model (VAEGAN) with intermediate outputs, and I have two losses for the encoder: a KL loss and a reconstruction loss.
Can I simply sum them and apply the gradients, as below?
with tf.GradientTape() as tape:
    z_mean, z_log_sigma, z_encoder_output = self.encoder(real_images, training=True)
    kl_loss = self.kl_loss_fn(z_mean, z_log_sigma) * kl_loss_coeff
    fake_images = self.decoder(z_encoder_output)
    fake_inter_activations, logits_fake = self.discriminator(fake_images, training=True)
    real_inter_activations, logits_real = self.discriminator(real_images, training=True)
    rec_loss = self.rec_loss_fn(fake_inter_activations, real_inter_activations) * rec_loss_coeff
    total_encoder_loss = kl_loss + rec_loss
grads = tape.gradient(total_encoder_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads, self.encoder.trainable_weights))
Or do I need to separate them, as below, while keeping the tape persistent?
with tf.GradientTape(persistent=True) as tape:
    z_mean, z_log_sigma, z_encoder_output = self.encoder(real_images, training=True)
    kl_loss = self.kl_loss_fn(z_mean, z_log_sigma) * kl_loss_coeff
    fake_images = self.decoder(z_encoder_output)
    fake_inter_activations, logits_fake = self.discriminator(fake_images, training=True)
    real_inter_activations, logits_real = self.discriminator(real_images, training=True)
    rec_loss = self.rec_loss_fn(fake_inter_activations, real_inter_activations) * rec_loss_coeff
grads_kl_loss = tape.gradient(kl_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads_kl_loss, self.encoder.trainable_weights))
grads_rec_loss = tape.gradient(rec_loss, self.encoder.trainable_weights)
self.e_optimizer.apply_gradients(zip(grads_rec_loss, self.encoder.trainable_weights))
Upvotes: 0
Views: 783
Reputation: 10473
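Yes, you can generally sum the losses and compute a single gradient. The gradient of a sum is the sum of the respective gradients, so with plain gradient descent the step taken with the summed loss is the same as taking both steps one after another. With optimizers that keep internal state between updates (momentum, Adam), the two variants are no longer exactly identical, but summing the losses is still the usual approach.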
Here's a simple example: Say you have two weights, and you are currently at the point (1, 3) ("starting point"). The gradient for loss 1 is (2, -4) and the gradient for loss 2 is (1, 2).
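To finish the arithmetic, here is a minimal sketch in plain Python. It assumes plain gradient descent with a learning rate of 1.0; the learning rate and the helper `sgd_step` are assumptions made for illustration and are not part of TensorFlow or of the code in the question.

```python
# Compare one step on the summed gradient with two separate steps.
# Assumption: plain gradient descent, learning rate 1.0 (hypothetical choice).

def sgd_step(point, grad, lr=1.0):
    """Return the point reached after one plain gradient-descent step."""
    return tuple(p - lr * g for p, g in zip(point, grad))

start = (1.0, 3.0)        # starting point from the example
grad_loss1 = (2.0, -4.0)  # gradient of loss 1 at the starting point
grad_loss2 = (1.0, 2.0)   # gradient of loss 2 at the starting point

# Variant A: sum the gradients, then take a single step.
grad_sum = tuple(a + b for a, b in zip(grad_loss1, grad_loss2))
single = sgd_step(start, grad_sum)

# Variant B: take two steps, both using gradients computed at the start.
after_first = sgd_step(start, grad_loss1)
sequential = sgd_step(after_first, grad_loss2)

print(grad_sum)     # (3.0, -2.0)
print(single)       # (-2.0, 5.0)
print(after_first)  # (-1.0, 7.0)
print(sequential)   # (-2.0, 5.0)
```

Both variants end at (-2, 5). This matches the persistent-tape version in the question, where both gradients are computed from the same forward pass before either update is applied.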
Upvotes: 1