jeffery_the_wind

How to override gradient vector calculation method for optimization algos in Keras, Tensorflow?

So I am trying to modify a couple of the optimization algorithms in Keras, namely Adam or just SGD. By default, I'm pretty sure the parameter update works like this: the loss is averaged over the data points in the batch, and a gradient vector is then computed from this averaged loss. Equivalently, you can think of it as computing the gradient of each data point's loss with respect to the model parameters and then averaging those gradients over the batch. This is the calculation I want to change. It is going to be expensive, so I want to do it inside the optimized framework so that it runs on the GPU.
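The equivalence described above (gradient of the mean loss = mean of the per-example gradients) can be checked numerically. This is a minimal NumPy sketch that assumes a toy linear model with a squared-error loss, using randomly generated data purely for illustration:

```python
import numpy as np

# Toy linear model y_hat = X @ w with a squared-error loss per example.
rng = np.random.default_rng(0)
X = rng.normal(size=(5, 3))   # batch of 5 examples, 3 features each
y = rng.normal(size=5)
w = rng.normal(size=3)

residual = X @ w - y                              # shape (5,)
# d/dw of (x_i . w - y_i)^2 is 2 * (x_i . w - y_i) * x_i, one row per example
per_example_grads = 2 * residual[:, None] * X     # shape (5, 3)

# Gradient of the batch-averaged loss (1/B) * sum_i (x_i . w - y_i)^2
grad_of_mean_loss = 2 * X.T @ residual / len(y)   # shape (3,)

# Averaging the per-example gradients gives the same vector.
print(np.allclose(per_example_grads.mean(axis=0), grad_of_mean_loss))  # True
```

Replacing `.mean(axis=0)` with some other reduction over the first axis is exactly the kind of change the question is after.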

So, for every batch, I need to calculate the gradient of the loss with respect to the parameters separately for each data point in the batch. Then, instead of taking the mean of those gradients, I will apply some other average or calculation. Does anyone know how I could override this part of Adam or SGD?

After a helpful comment, I found that there should be a way to do this with the jacobian method of GradientTape. However, the documentation isn't very thorough, and I can't figure out how it fits into the overall picture. I'm hoping someone can help me adjust the code below to use jacobian instead of gradient.

As a hello-world example, I am trying to simply replace the gradient line with code that uses jacobian and produces the same output. This would illustrate how to use the jacobian method and how its output relates to the output of the gradient method.
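A minimal sketch of that hello-world comparison, outside of Keras, assuming a tiny linear model built from raw tf.Variable objects and a per-example loss computed by hand (all names and shapes here are illustrative). The jacobian of the per-example losses, averaged over its leading batch axis, should match the ordinary gradient of the batch-averaged loss:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
x = tf.random.normal((4, 3))              # batch of 4 examples, 3 features
y = tf.random.normal((4, 1))
w = tf.Variable(tf.random.normal((3, 1)))
b = tf.Variable(tf.zeros((1,)))

# persistent=True because the tape is used twice (gradient, then jacobian).
with tf.GradientTape(persistent=True) as tape:
    y_pred = x @ w + b                                                  # (4, 1)
    per_example_loss = tf.reduce_mean(tf.square(y - y_pred), axis=-1)  # (4,)
    mean_loss = tf.reduce_mean(per_example_loss)                        # scalar

# Standard path: one gradient of the scalar, batch-averaged loss.
grads = tape.gradient(mean_loss, [w, b])

# Jacobian path: per-example gradients stacked on a new leading axis.
jacs = tape.jacobian(per_example_loss, [w, b])   # shapes (4, 3, 1) and (4, 1)
del tape

# Averaging over the batch axis recovers the ordinary gradient.
recovered = [tf.reduce_mean(j, axis=0) for j in jacs]
for g, r in zip(grads, recovered):
    print(np.allclose(g.numpy(), r.numpy(), atol=1e-5))  # True
```

The reduce_mean over axis 0 is the step to swap out for a different per-example aggregation.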

Working Code

import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars) # <-- line to change
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

Answers (1)

xdurch0

You should be able to do the following:

import tensorflow as tf
from tensorflow import keras

class CustomModel(keras.Model):
    def train_step(self, data):
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)  # Forward pass
            # Compute the loss value
            # (the loss function is configured in `compile()`)
            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

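        # Compute per-example gradients: each entry has shape
        # (batch_size,) + variable.shape, assuming loss has shape (batch_size,)
        trainable_vars = self.trainable_variables
        gradients = tape.jacobian(loss, trainable_vars)

        # do_something_to is a placeholder for your own reduction; it must
        # remove the leading batch axis so new_grad matches the variable's shape
        new_gradients = []
        for grad in gradients:
            new_grad = do_something_to(grad)
            new_gradients.append(new_grad)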

        # Update weights
        self.optimizer.apply_gradients(zip(new_gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
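As one possible choice for the placeholder do_something_to, here is a hedged sketch of a hypothetical reduction that clips each example's gradient to a maximum L2 norm and then averages over the batch axis. The clipping rule and the CLIP_NORM value are assumptions for illustration only; note that this clips each variable's per-example gradient separately, not the norm across all variables combined:

```python
import tensorflow as tf

CLIP_NORM = 1.0  # hypothetical threshold, chosen only for illustration


def do_something_to(grad, clip_norm=CLIP_NORM):
    """Hypothetical reduction: clip each example's gradient (for this one
    variable) to L2 norm <= clip_norm, then average over the batch axis."""
    batch_size = tf.shape(grad)[0]
    flat = tf.reshape(grad, [batch_size, -1])                   # (batch, n_params)
    norms = tf.norm(flat, axis=1, keepdims=True)                # (batch, 1)
    # Scale factor is 1 for small gradients, clip_norm / norm for large ones.
    scale = tf.minimum(1.0, clip_norm / tf.maximum(norms, 1e-12))
    clipped = tf.reshape(flat * scale, tf.shape(grad))          # (batch,) + var_shape
    return tf.reduce_mean(clipped, axis=0)                      # var_shape


# Two per-example gradients with norms 5.0 and 0.5: only the first is clipped.
example = tf.constant([[3.0, 4.0], [0.3, 0.4]])
averaged = do_something_to(example)
print(averaged.numpy())  # approximately [0.45, 0.6]
```

Whatever reduction you use, the returned tensor must drop the leading batch axis so that apply_gradients receives gradients with the same shapes as the variables.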

Some important notes: the loss returned by the compiled_loss function must not be averaged over the batch axis; that is, I'm assuming it is a tensor of shape (batch_size,), not a scalar.
This causes jacobian to return gradients of shape (batch_size,) + variable_shape, i.e. you now have per-example gradients. You can manipulate these gradients however you want, but at some point you must get rid of the extra batch axis (e.g. by averaging), so that new_grad has the same shape as the corresponding variable.
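The shape rule can be seen with a small sketch. It assumes a hand-built layer (a kernel of shape (3, 2) and a bias of shape (2,), mimicking a dense layer with 3 inputs and 2 units) and a per-example loss computed by hand, so it does not depend on how a particular Keras version reduces compiled losses. In Keras 2, built-in losses can also be made per-example by passing reduction=tf.keras.losses.Reduction.NONE.

```python
import tensorflow as tf

tf.random.set_seed(1)
batch_size = 4
kernel = tf.Variable(tf.random.normal((3, 2)))   # like a dense kernel
bias = tf.Variable(tf.zeros((2,)))               # like a dense bias
x = tf.random.normal((batch_size, 3))
y = tf.random.normal((batch_size, 2))

with tf.GradientTape() as tape:
    y_pred = x @ kernel + bias
    # One loss per example: average over the feature axis only, not the batch.
    loss = tf.reduce_mean(tf.square(y - y_pred), axis=-1)   # shape (4,)

jac = tape.jacobian(loss, [kernel, bias])
for var, j in zip([kernel, bias], jac):
    print(var.shape, "->", j.shape)
# (3, 2) -> (4, 3, 2)
# (2,) -> (4, 2)
```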

Regarding your last comment: as I mentioned, the loss function does need to return one loss per data point, i.e. it must not average over the batch. However, that alone is not enough: if you pass this vector to tape.gradient, it will simply sum the loss values, because gradient computes the gradient of a scalar and treats a non-scalar target as the sum of its elements. This is why jacobian is necessary.
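The summing behaviour of tape.gradient on a vector target can be confirmed with a small example. This sketch assumes a toy quadratic per-example loss (x_i · w)^2 on hand-picked numbers, so the gradients are easy to check by hand:

```python
import numpy as np
import tensorflow as tf

w = tf.Variable([1.0, 2.0])
x = tf.constant([[1.0, 0.0], [0.0, 3.0], [2.0, 1.0]])   # batch of 3

with tf.GradientTape(persistent=True) as tape:
    per_example = tf.square(tf.linalg.matvec(x, w))   # [1, 36, 16], shape (3,)
    summed = tf.reduce_sum(per_example)

g_vector = tape.gradient(per_example, w)   # non-scalar target
g_sum = tape.gradient(summed, w)           # explicit sum
del tape

# Per-example gradients are 2 * (x_i . w) * x_i = [2, 0], [0, 36], [16, 8].
print(g_sum.numpy())                                  # [18. 44.]
print(np.allclose(g_vector.numpy(), g_sum.numpy()))   # True
```

The individual per-example gradients are lost in that sum, which is what jacobian preserves.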

Finally, jacobian can be very slow. In the worst case, run time may be multiplied by the batch size, because it needs to compute that many separate gradients. However, by default (experimental_use_pfor=True) jacobian vectorizes these per-example computations, so the slowdown might not be as bad.
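For reference, jacobian exposes both strategies through its experimental_use_pfor argument. This sketch, using toy random data as an assumption, checks that the vectorized default and the loop-based fallback give the same per-example gradients; it makes no claim about their relative speed, which depends on the model and hardware:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(2)
x = tf.random.normal((8, 5))
w = tf.Variable(tf.random.normal((5, 1)))

# persistent=True: jacobian is called twice, and the non-pfor path
# requires a persistent tape when running eagerly.
with tf.GradientTape(persistent=True) as tape:
    per_example_loss = tf.reduce_sum(tf.square(x @ w), axis=-1)   # (8,)

# Default: vectorized computation of all per-example gradients.
jac_pfor = tape.jacobian(per_example_loss, w)
# Fallback: a loop that computes one gradient per example.
jac_loop = tape.jacobian(per_example_loss, w, experimental_use_pfor=False)
del tape

print(jac_pfor.shape)   # (8, 5, 1)
print(np.allclose(jac_pfor.numpy(), jac_loop.numpy(), atol=1e-5))  # True
```

If vectorization fails for some op in your model, the loop-based fallback is the usual workaround, at the cost of running the gradient computation once per example.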
