eitanrich

Reputation: 321

How to make sure your computation graph is differentiable

Some TensorFlow operations (e.g. tf.argmax) are not differentiable, i.e. no gradients are computed for them and used in back-propagation.

An answer to "Tensorflow what operations are differentiable and what are not?" suggests searching for RegisterGradient in the TensorFlow source. I also noticed TensorFlow has a tf.NotDifferentiable API call for declaring an operation non-differentiable.

Is there a warning issued if I use non-differentiable functions? Is there a programmatic way to ensure that my entire computation graph is differentiable?

Upvotes: 3

Views: 1756

Answers (1)

Allen Lavoie

Reputation: 5808

Most floating point operations will have gradients, so a first-pass answer would just be to check that there are no int32/int64 dtype Tensors in the graph. This is easy to do, but probably not useful (i.e. any non-trivial model will be doing non-differentiable indexing operations).
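A minimal sketch of that first pass, assuming a TF 1.x-style graph (on TensorFlow 2.x the same API lives under tf.compat.v1):

```python
import tensorflow as tf

# Use the 1.x graph API; on TensorFlow 2.x it lives under tf.compat.v1.
tf1 = getattr(getattr(tf, "compat", tf), "v1", tf)

with tf1.Graph().as_default() as g:
  x = tf1.placeholder(tf.float32, [None])
  idx = tf1.argmax(x, axis=0)  # ArgMax emits an integer tensor
  # Flag every integer-dtype tensor produced by any op in the graph.
  int_tensors = [t.name for op in g.get_operations()
                 for t in op.outputs
                 if t.dtype.base_dtype in (tf.int32, tf.int64)]
  print(int_tensors)
```

As noted above, this flags a lot of harmless tensors (here even the constant holding the argmax axis), which is why it isn't very useful in practice.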

You could do some type of introspection, looping over the operations in the GraphDef and checking that they have gradients registered. I would argue that this is not terribly useful either; if we don't trust that gradients are registered in the first place, why trust that they're correct if registered?

Instead, I'd go with numerical gradient checking at a few points for your model. For example, let's say we register a PyFunc without a gradient:

import tensorflow as tf
import numpy
def my_func(x):
  return numpy.sinh(x)
with tf.Graph().as_default():
  inp = tf.placeholder(tf.float32)
  # py_func has no registered gradient, so its contribution is silently
  # dropped during back-propagation; only the "+ inp" term survives.
  y = tf.py_func(my_func, [inp], tf.float32) + inp
  grad, = tf.gradients(y, inp)
  with tf.Session() as session:
    print(session.run([y, grad], feed_dict={inp: 3}))
    # Compare the symbolic gradient against a numerical estimate.
    print("Gradient error:", tf.test.compute_gradient_error(inp, [], y, []))

This gets me output like:

[13.017875, 1.0]
Gradient error: 1.10916996002

Numerical gradients can be a bit tricky, but generally any gradient error more than a few orders of magnitude above machine epsilon (~1e-7 for float32) would raise red flags for me for a supposedly smooth function.
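If you want to see what the check is doing under the hood, the same idea fits in a few lines of plain NumPy (function and variable names here are my own): estimate the derivative with a central difference and compare it to what the graph claims.

```python
import numpy as np

def numeric_grad(f, x, eps=1e-4):
  # Central difference: (f(x + eps) - f(x - eps)) / (2 * eps)
  return (f(x + eps) - f(x - eps)) / (2.0 * eps)

x = 3.0
estimate = numeric_grad(np.sinh, x)
analytic = np.cosh(x)  # the true derivative of sinh
reported = 1.0         # what the graph above reports: only the "+ inp" path

print("analytic error:", abs(estimate - analytic))  # tiny
print("reported error:", abs(estimate - reported))  # large: red flag
```

The error against the true derivative sits near machine-epsilon scale, while the error against the graph's claimed gradient is huge, exactly the signal the compute_gradient_error check surfaces.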

Upvotes: 3
