Reputation: 179
For a function approximation problem I'm trying to accumulate gradients, but I find that some of these gradients are NaN (i.e. undefined) even though the loss is always real-valued. I think this is due to numerical instabilities, and I'm basically looking for a simple method to remove the NaNs from the computed gradients.
Starting from the solution to this question, I tried the following:
# Optimizer definition - nothing different from any classical example
opt = tf.train.AdamOptimizer()
## Retrieve all trainable variables you defined in your graph
tvs = tf.trainable_variables()
## Creation of a list of variables with the same shape as the trainable ones
# initialized with 0s
accum_vars = [tf.Variable(tf.zeros_like(tv.initialized_value()), trainable=False) for tv in tvs]
zero_ops = [av.assign(tf.zeros_like(av)) for av in accum_vars]
## Calls the compute_gradients function of the optimizer to obtain... the list of gradients
gvs_ = opt.compute_gradients(rmse, tvs)
gvs = tf.where(tf.is_nan(gvs_), tf.zeros_like(gvs_), gvs_)
## Adds to each element from the list you initialized earlier with zeros its gradient (works because accum_vars and gvs are in the same order)
accum_ops = [accum_vars[i].assign_add(gv[0]) for i, gv in enumerate(gvs)]
## Define the training step (part with variable value update)
train_step = opt.apply_gradients([(accum_vars[i], gv[1]) for i, gv in enumerate(gvs)])
So basically, the key idea is this line:
gvs = tf.where(tf.is_nan(gvs_), tf.zeros_like(gvs_), gvs_)
But when I apply this idea I obtain the following error:
ValueError: Tried to convert 'x' to a tensor and failed. Error: Dimension 1 in both shapes must be equal, but are 30 and 9. Shapes are [2,30] and [2,9]. From merging shape 2 with other shapes. for 'IsNan/packed' (op: 'Pack') with input shapes: [2,9,30], [2,30,9], [2,30], [2,9].
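For intuition on what's happening (a NumPy analogy, not the TensorFlow internals themselves): passing a Python list of tensors to tf.where makes TensorFlow try to pack the list into one tensor, and that Pack op fails because the gradients have different shapes, e.g. [2, 9, 30] vs [2, 30, 9] from the error message. The failure is analogous to stacking ragged arrays:

```python
import numpy as np

# Gradients with the mismatched shapes from the error message.
grads = [np.zeros((2, 9, 30)), np.zeros((2, 30, 9))]

# Trying to combine them into a single tensor fails, just like
# the 'Pack' op in the traceback.
try:
    np.stack(grads)
except ValueError as e:
    print("stacking failed:", e)
```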
Upvotes: 1
Views: 1580
Reputation: 17191
compute_gradients
returns a list of (gradient, variable) pairs in your case, so you can't pass the whole list to tf.where at once. You may want to apply the masking to each gradient individually:
gvs_ = [(tf.where(tf.is_nan(grad), tf.zeros_like(grad), grad), val) for grad, val in gvs_]
Upvotes: 1