So I am trying to modify a couple of the optimization algorithms in Keras, namely Adam or just SGD. By default, I'm pretty sure the parameter updates work like this: the loss is averaged over the data points in the batch, and a gradient vector is then calculated from that averaged loss. Equivalently, you can think of it as averaging the gradients of the per-data-point losses over the batch. This averaging is the calculation I want to change, and it is going to be expensive, so I am trying to do it inside the optimized framework that uses the GPU and all that.
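For what it's worth, the two views really are the same computation; here is a minimal self-contained sketch (toy linear model and data, purely hypothetical) that checks the gradient of the mean loss against the mean of the per-example gradients:

import tensorflow as tf

w = tf.Variable([[1.0], [2.0]])                         # (2, 1)
x = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])   # batch of 3, (3, 2)
y = tf.constant([[1.0], [0.0], [2.0]])                  # (3, 1)

with tf.GradientTape(persistent=True) as tape:
    y_pred = tf.matmul(x, w)                                      # (3, 1)
    per_example_loss = tf.reduce_sum((y_pred - y) ** 2, axis=1)   # (3,)
    mean_loss = tf.reduce_mean(per_example_loss)                  # scalar

g_mean = tape.gradient(mean_loss, w)       # gradient of the averaged loss
jac = tape.jacobian(per_example_loss, w)   # (3, 2, 1): one gradient per example
g_avg = tf.reduce_mean(jac, axis=0)        # average the per-example gradients
tf.debugging.assert_near(g_mean, g_avg)    # both views agree
del tape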
So, for every batch, I need to calculate the gradient of the loss for each data point in the batch individually, and then, instead of taking the mean of those gradients, apply some other average or calculation. Does anyone know how I can override this behavior of Adam or SGD?
After a great comment, I found that there should be a way to do what I am trying to do with the jacobian method from GradientTape. However, the documentation isn't very thorough, and I can't figure out how it fits into the overall picture. Here I am hoping someone can help me adjust the code to use jacobian instead of gradient.
As a hello-world example, I am trying to simply replace the gradient line with some code that uses jacobian and produces the same output. This will illustrate how to use the jacobian method and its connection to the output of the gradient method.
Working Code
import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)  # <-- line to change
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
Answer
You should be able to do the following:
class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.jacobian(loss, trainable_vars)
        new_gradients = []
        for grad in gradients:
            new_grad = do_something_to(grad)
            new_gradients.append(new_grad)
        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
Some important notes: the loss returned by the compiled_loss function must not average over the batch axis, i.e. I'm assuming it is a tensor of shape (batch_size,), not a scalar.
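For instance, with a built-in Keras loss you can request an unreduced loss via the standard reduction argument (a sketch; depending on your Keras version you should verify that compiled_loss actually preserves the per-example shape):

model.compile(
    optimizer="adam",
    # Reduction.NONE keeps one loss value per example, i.e. shape (batch_size,)
    loss=tf.keras.losses.MeanSquaredError(
        reduction=tf.keras.losses.Reduction.NONE
    ),
)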
This will cause jacobian to return gradients of shape (batch_size,) + variable_shape, i.e. you now have per-batch-element gradients. You can manipulate these gradients however you want, and at some point you should of course get rid of the additional batch axis (e.g. by averaging); that is, new_grad should have the same shape as the corresponding variable.
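To make that concrete, here is one hypothetical do_something_to (per-example gradient clipping followed by a mean is just an illustrative choice, not something the question prescribes):

def do_something_to(grad):
    # grad has shape (batch_size,) + variable_shape: one gradient per example.
    # Illustrative choice: clip each per-example gradient, then collapse the
    # batch axis with a mean so the result matches the variable's shape.
    clipped = tf.clip_by_value(grad, -1.0, 1.0)
    return tf.reduce_mean(clipped, axis=0)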
Regarding your last comment: as I mentioned, the loss function indeed needs to return one loss per data point, i.e. it must not average over the batch. However, this alone is not enough, because if you were to give this vector to tape.gradient, the gradient function would simply sum up the loss values (since it only works with scalars). This is why jacobian is necessary.
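You can see the summing behavior directly in a tiny sketch (toy values):

x = tf.Variable([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    losses = x ** 2  # vector target of shape (3,)

# tape.gradient implicitly sums the vector target first:
# d(sum(x**2))/dx = 2 * x, so the per-example structure is lost.
print(tape.gradient(losses, x))  # [2. 4. 6.]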
Finally, jacobian can be very slow. In the worst case, the run time may be multiplied by the batch size, because it needs to compute that many separate gradients. However, this is done in parallel to some degree, so the slowdown might not be as bad.
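If you run into ops that the vectorized implementation cannot handle, jacobian accepts an experimental_use_pfor flag to fall back to a sequential while_loop (slower, but more broadly compatible):

# Default: vectorized with a parallel-for, which is what makes it
# faster than batch_size separate gradient computations.
gradients = tape.jacobian(loss, trainable_vars)

# Fallback: sequential while_loop; slower but supports more ops.
gradients = tape.jacobian(loss, trainable_vars, experimental_use_pfor=False)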