zergylord

Reputation: 4436

How to directly set the gradient of a layer before backpropagation?

Imagine a tiny network defined as follows, where linear is a typical helper function defining TensorFlow variables for a weight matrix and activation function:

final_layer = linear(linear(_input,10,tf.nn.tanh),20)

Normally this would be optimized via gradient descent on a loss:

loss = tf.reduce_sum(tf.square(final_layer - _target))
train_step = tf.train.AdamOptimizer().minimize(loss)

But assume I'm getting the derivatives of the loss w.r.t. final_layer from an external source (e.g. a tf.placeholder named _deriv). How can I use this gradient information with one of the builtin optimizers to backpropagate and update the network parameters?

The workaround I'm currently using is to construct an artificial loss consisting of the inner product between _deriv and final_layer (since the derivatives of this loss w.r.t. final_layer will be equal to _deriv).

loss = tf.reduce_sum(final_layer * _deriv)
train_step = tf.train.AdamOptimizer().minimize(loss)

This is wasteful, though, as it computes the unnecessary inner product and its derivative on every training step even though I already know this information. Is there a better way?

For those thinking this is an odd thing to need to do: it is necessary for implementing synthetic gradients.

Upvotes: 4

Views: 434

Answers (2)

Bas Krahmer

Reputation: 651

For those wondering, a nice way to do this in TensorFlow 2 is to customize what happens in model.fit. Specifically, you can override the train_step function to disregard the native GradientTape() and instead pass your externally computed gradients to the optimizer.
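A minimal sketch of that idea, assuming the external gradients arrive already matched one-to-one with the model's trainable variables; the class name and the convention of passing the gradients through the data argument are hypothetical:

import tensorflow as tf

class ExternalGradientModel(tf.keras.Model):
    """Sketch: a Model whose train_step skips the GradientTape and
    applies externally supplied gradients directly."""

    def train_step(self, data):
        # Hypothetical convention: the second element of each batch holds
        # one external gradient tensor per trainable variable.
        _, external_grads = data
        self.optimizer.apply_gradients(
            zip(external_grads, self.trainable_variables))
        # No native loss is computed; report a placeholder value.
        return {"loss": 0.0}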

Upvotes: 0

lballes

Reputation: 1502

tf.gradients provides this functionality via its grad_ys argument (see the tf.gradients documentation). In your case, tf.gradients([final_layer], list_of_variables, grad_ys=[_deriv]) would compute the gradients you want.

Unfortunately, it looks like the built-in optimizers don't pass a grad_ys argument to tf.gradients. You might have to hack something into the compute_gradients method of the optimizer class.
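A minimal sketch of that route in graph-mode TF1, matching the question's setup (final_layer and _deriv are the tensors from the question); it sidesteps compute_gradients by building the (gradient, variable) pairs manually and handing them to apply_gradients:

import tensorflow as tf

params = tf.trainable_variables()
# Backpropagate the externally supplied output gradient through the network.
param_grads = tf.gradients([final_layer], params, grad_ys=[_deriv])
# Feed the resulting (gradient, variable) pairs straight to the optimizer.
train_step = tf.train.AdamOptimizer().apply_gradients(list(zip(param_grads, params)))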

Upvotes: 2
