D.Badawi

Reputation: 175

Tensorflow: Don't Update if gradient is NaN

I have a deep model to train on CIFAR-10. Training works fine on the CPU. However, when I use GPU support, the gradients for some batches become NaNs (I verified this using tf.check_numerics). It happens randomly, but early in training. I believe the problem is related to my GPU.

My question is: is there a way to skip the update when at least one of the gradients contains NaNs, and force the model to proceed to the next batch?

Edit: Perhaps I should elaborate more on my problem.

This is how I apply the gradients:

with tf.control_dependencies([tf.check_numerics(grad, message='Gradient %s check failed, possible NaNs' % var.name)
                              for grad, var in grads]):
  # Apply the gradients to adjust the shared variables.
  apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)

I have thought of using tf.check_numerics first to verify that the gradients are NaN-free, and then, if the check fails, "pass" without calling opt.apply_gradients for that batch. However, is there a way to catch the error raised by tf.check_numerics inside tf.control_dependencies?
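One thing I can do at run time (a sketch, assuming my training loop calls sess.run directly) is catch the error that the check op raises when it executes, since tf.check_numerics fails with tf.errors.InvalidArgumentError at execution time:

try:
  sess.run(apply_gradient_op)
except tf.errors.InvalidArgumentError as e:
  # The check_numerics op failed for this batch; skip the update and move on.
  print('Skipping batch: %s' % e.message)

But I would prefer to handle this inside the graph rather than in the Python loop.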

Upvotes: 1

Views: 1109

Answers (1)

D.Badawi

Reputation: 175

I managed to figure it out, albeit not in the most elegant way. My solution is as follows: 1) check all gradients first; 2) if the gradients are NaN-free, apply them; 3) otherwise, apply a fake update (with zero values), which requires a gradient override.

This is my code:

First define custom gradient:

# Gradient override that replaces the incoming gradient with zeros.
@tf.RegisterGradient("ZeroGrad")
def _zero_grad(unused_op, grad):
  return tf.zeros_like(grad)
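As a quick sanity check of the override (a sketch with a hypothetical variable x), the gradient flowing through tf.identity becomes zero once the override map is active:

x = tf.Variable(3.0)
g = tf.get_default_graph()
with g.gradient_override_map({"Identity": "ZeroGrad"}):
  y = tf.identity(x)
dy_dx = tf.gradients(y, x)[0]  # 0.0 instead of the usual 1.0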

Then define an exception-handling function:

# Check a gradient for NaNs; returns a boolean constant instead of
# letting tf.check_numerics abort the run.
def check_numerics_with_exception(grad, var):
  try:
    tf.check_numerics(grad, message='Gradient %s check failed, possible NaNs' % var.name)
  except:
    return tf.constant(False, shape=())
  else:
    return tf.constant(True, shape=())

Then create the conditional node:

num_nans_grads = tf.Variable(1.0, name='num_nans_grads')
check_all_numeric_op = tf.reduce_sum(tf.cast(tf.stack(
    [tf.logical_not(check_numerics_with_exception(grad, var)) for grad, var in grads]),
    dtype=tf.float32))

with tf.control_dependencies([tf.assign(num_nans_grads, check_all_numeric_op)]):
  # Apply the gradients to adjust the shared variables.
  def fn_true_apply_grad(grads, global_step):
    apply_gradients_true = opt.apply_gradients(grads, global_step=global_step)
    return apply_gradients_true

  def fn_false_ignore_grad(grads, global_step):
    # print('batch update ignored due to nans, fake update is applied')
    g = tf.get_default_graph()
    with g.gradient_override_map({"Identity": "ZeroGrad"}):
      for (grad, var) in grads:
        tf.assign(var, tf.identity(var, name="Identity"))
        apply_gradients_false = opt.apply_gradients(grads, global_step=global_step)
    return apply_gradients_false

  apply_gradient_op = tf.cond(tf.equal(num_nans_grads, 0.),
                              lambda: fn_true_apply_grad(grads, global_step),
                              lambda: fn_false_ignore_grad(grads, global_step))
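A shorter variant (just a sketch, not exactly what I used; it assumes TF 1.x graph mode and tf.is_finite) is to compute a single "all gradients finite" flag and substitute zero gradients when it is False, so that apply_gradients becomes a harmless no-op on bad batches:

all_finite = tf.reduce_all(tf.stack(
    [tf.reduce_all(tf.is_finite(grad)) for grad, var in grads]))

def zero_if_bad(grad):
  # Replace the whole gradient with zeros when any gradient has NaNs/Infs.
  return tf.cond(all_finite, lambda: grad, lambda: tf.zeros_like(grad))

safe_grads = [(zero_if_bad(grad), var) for grad, var in grads]
apply_gradient_op = opt.apply_gradients(safe_grads, global_step=global_step)

Note that with stateful optimizers (momentum, Adam) a zero gradient still updates the slot variables, and global_step advances even on skipped batches.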

Upvotes: 2
