Reputation: 312
I am trying to train a discriminator network by applying gradients with its optimizer. However, when I use a tf.GradientTape to compute the gradients of the loss with respect to the trainable variables, None is returned. Here is the training loop:
def train_step():
    # Generate noisy seeds
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
    with tf.GradientTape() as disc_tape:
        pattern = generator(noise)
        pattern = tf.reshape(tensor=pattern, shape=(28,28,1))
        dataset = get_data_set(pattern)
        disc_loss = tf.Variable(shape=(1,2), initial_value=[[0,0]], dtype=tf.float32)
        disc_tape.watch(disc_loss)
        for batch in dataset:
            disc_loss.assign_add(discriminator(batch, training=True))
    disc_gradients = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
The generator network generates a 'pattern' from noise. I then generate a dataset from that pattern by applying various convolutions to the tensor. The dataset that is returned is batched, so I iterate through the dataset and keep track of the loss of my discriminator by adding the loss from this batch to the total loss.
tf.GradientTape returns None when there is no graph connection between the target and the variables it is differentiated against. But isn't there a graph connection between the loss and the trainable variables? I believe my mistake has something to do with how I keep track of the loss in the disc_loss tf.Variable.
How do I keep track of loss while iterating through a batched dataset so that I may use it later to calculate gradients?
Upvotes: 1
Views: 114
Reputation: 312
The short answer is that the assign_add method of tf.Variable is a stateful assignment, and the gradient tape does not differentiate through it, so no gradient can be calculated between the variable disc_loss and the discriminator's trainable variables.
In this specific case, the fix was to replace the assign_add call with ordinary addition:
disc_loss = disc_loss + discriminator(batch, training=True)
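To make the difference concrete, here is a minimal sketch comparing the two ways of accumulating the loss. The one-layer discriminator, the random dataset, and the helper name accumulated_gradients are stand-ins I made up for illustration; they are not the networks or get_data_set from the question.

```python
import tensorflow as tf

# Hypothetical stand-in for the discriminator: a single Dense layer with 2 outputs.
discriminator = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])

# Hypothetical stand-in for the batched dataset: 6 random samples, one per batch,
# so each discriminator call returns a tensor of shape (1, 2).
dataset = tf.data.Dataset.from_tensor_slices(tf.random.normal([6, 4])).batch(1)


def accumulated_gradients(use_assign_add):
    """Accumulate the per-batch loss and return its gradients.

    use_assign_add=True reproduces the approach from the question;
    use_assign_add=False uses plain tensor addition instead.
    """
    with tf.GradientTape() as tape:
        if use_assign_add:
            # Stateful update: the tape does not trace a path from the
            # variable back to the discriminator weights.
            total = tf.Variable([[0.0, 0.0]])
            for batch in dataset:
                total.assign_add(discriminator(batch, training=True))
        else:
            # Ordinary addition: every step is recorded on the tape.
            total = tf.zeros((1, 2))
            for batch in dataset:
                total = total + discriminator(batch, training=True)
    return tape.gradient(total, discriminator.trainable_variables)


broken = accumulated_gradients(use_assign_add=True)
fixed = accumulated_gradients(use_assign_add=False)
```

With this setup, every entry of broken should be None, while fixed should hold one gradient tensor per trainable variable, which can then be passed to the optimizer's apply_gradients.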
In future cases of similar problems, be sure to check that all operations used while being watched by the gradient tape are differentiable.
This link has a list of differentiable and non-differentiable TensorFlow ops. I found it very useful.
Upvotes: 1