I'm trying to define a gradient method for my custom TF operation. Most of the solutions I have found online seem to be based on a gist by harpone. I'm reluctant to use that approach because it uses py_func, which won't run on GPU. I found another solution here that uses tf.identity() that looks more elegant and that I think will run on GPU. However, I have some problems accessing the inputs of the ops in my custom gradient function. Here's my code:
import tensorflow as tf

@tf.RegisterGradient('MyCustomGradient')
def _custom_gradient(op, gradients):
    x = op.inputs[0]  # forward-pass value of the op's input
    return x

def my_op(w):
    return tf.pow(w, 3)

var_foo = tf.Variable(5, dtype=tf.float32)
bar = my_op(var_foo)

g = tf.get_default_graph()
with g.gradient_override_map({'Identity': 'MyCustomGradient'}):
    bar = tf.identity(bar)  # identity op picks up the custom gradient

g = tf.gradients(bar, var_foo)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(g))  # prints [9375.0]
I was expecting _custom_gradient() to return the input to the op (5 in this example), but instead it seems to return op output × gradient. My custom my_op will contain non-differentiable operations like tf.sign, and I'd like to define my custom gradient based on the inputs. What am I doing wrong?
Upvotes: 0
Views: 433
There is no problem with your code:
Let's first do the forward pass:
var_foo = 5
-> bar = 125
-> tf.identity(bar) = 125
Now let's backpropagate:
The gradient of tf.identity(bar) with respect to its argument bar equals (by your definition) bar, that is, 125. The gradient of bar with respect to var_foo equals 3 times the square of var_foo, which is 75. Multiply the two, and you get 9375, which is indeed the output of your code.
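To make the chain rule explicit, here is the same computation in plain Python arithmetic (a sketch of what tf.gradients composes here, not the actual TensorFlow internals):

var_foo = 5.0
bar = var_foo ** 3              # forward pass: 125.0
grad_identity = bar             # your custom gradient returns op.inputs[0] = 125.0
grad_bar = 3.0 * var_foo ** 2   # d(w**3)/dw = 3*w**2 = 75.0
print(grad_identity * grad_bar) # 9375.0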
op.inputs[0] contains the forward-pass value of the op's input. In this case, the input to the identity op is 125 (the output of my_op), not the original 5.
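Since you mention tf.sign: if you want a gradient defined in terms of the inputs, you can override the gradient of the sign op itself rather than wrapping it in an identity, so that op.inputs[0] is the value you actually care about. A minimal sketch, using the TF 1.x graph API as in your code; the straight-through-style clipping rule here is just an illustrative choice, not something your question prescribes:

import tensorflow as tf

@tf.RegisterGradient('STESign')
def _ste_sign_grad(op, grad):
    x = op.inputs[0]                                   # forward-pass input to sign
    pass_through = tf.cast(tf.abs(x) <= 1.0, x.dtype)  # illustrative surrogate
    return grad * pass_through                         # keep the chain rule intact

g = tf.get_default_graph()
w = tf.Variable(0.5, dtype=tf.float32)
with g.gradient_override_map({'Sign': 'STESign'}):
    y = tf.sign(w)  # the default gradient of Sign is zero everywhere

dw = tf.gradients(y, w)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(dw))  # [1.0], because |0.5| <= 1.0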
Upvotes: 2