Andrew Wiedenmann

Reputation: 312

How should I keep track of total loss while training a network with a batched dataset?

I am attempting to train a discriminator network by computing gradients and applying them with its optimizer. However, when I use a tf.GradientTape to find the gradients of the loss w.r.t. the trainable variables, None is returned. Here is the training loop:

def train_step():
  #Generate noisy seeds
  noise = tf.random.normal([BATCH_SIZE, noise_dim])
  with tf.GradientTape() as disc_tape:
    pattern = generator(noise)
    pattern = tf.reshape(tensor=pattern, shape=(28,28,1))
    dataset = get_data_set(pattern)
    disc_loss = tf.Variable(shape=(1,2), initial_value=[[0,0]], dtype=tf.float32)
    disc_tape.watch(disc_loss)
    for batch in dataset:
        disc_loss.assign_add(discriminator(batch, training=True))

  disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

Code Description

The generator network generates a 'pattern' from noise. I then generate a dataset from that pattern by applying various convolutions to the tensor. The dataset that is returned is batched, so I iterate through the dataset and keep track of my discriminator's loss by adding the loss from each batch to the total loss.

What I do know

tf.GradientTape returns None when there is no graph connection between the two tensors. But isn't there a graph connection between the loss and the trainable variables? I believe my mistake has something to do with how I keep track of the loss in the disc_loss tf.Variable.

My Question

How do I keep track of loss while iterating through a batched dataset so that I may use it later to calculate gradients?

Upvotes: 1

Views: 114

Answers (1)

Andrew Wiedenmann

Reputation: 312

The short answer is that the assign_add method of tf.Variable is a stateful assignment, not a differentiable operation, so the tape records no path from the discriminator's trainable variables to disc_loss and the gradient comes back as None.
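As a minimal illustration (not the original code), accumulating into a tf.Variable with assign_add yields a None gradient, while ordinary tensor addition keeps the connection on the tape:

import tensorflow as tf

w = tf.Variable(2.0)

with tf.GradientTape() as tape:
  acc = tf.Variable(0.0)
  tape.watch(acc)
  acc.assign_add(w * 3.0)     # stateful assignment; not recorded as a differentiable op
print(tape.gradient(acc, w))  # None

with tf.GradientTape() as tape:
  acc = tf.constant(0.0)
  acc = acc + w * 3.0          # plain tensor addition stays on the tape
print(tape.gradient(acc, w))   # tf.Tensor(3.0, shape=(), dtype=float32)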

In this specific case, the fix was to accumulate the loss in a plain tensor instead of calling assign_add:

disc_loss = disc_loss + discriminator(batch, training=True)
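Put back into the training loop, the corrected train_step looks roughly like this (a sketch that reuses generator, discriminator, get_data_set, BATCH_SIZE and noise_dim as defined in the question):

def train_step():
  # Generate noisy seeds
  noise = tf.random.normal([BATCH_SIZE, noise_dim])
  with tf.GradientTape() as disc_tape:
    pattern = generator(noise)
    pattern = tf.reshape(tensor=pattern, shape=(28, 28, 1))
    dataset = get_data_set(pattern)

    # Accumulate the loss in an ordinary tensor instead of a tf.Variable,
    # so every addition is recorded on the tape.
    disc_loss = tf.zeros(shape=(1, 2), dtype=tf.float32)
    for batch in dataset:
      disc_loss = disc_loss + discriminator(batch, training=True)

  # Gradients now flow from disc_loss back to the discriminator's weights.
  disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)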

When debugging similar problems in the future, check that every operation recorded while the gradient tape is watching is differentiable.

This link has a list of differentiable and non-differentiable TensorFlow ops; I found it very useful.

Upvotes: 1
