bored_to_death

Reputation: 383

stopping condition on gradient value tensorflow

I would like to implement a stopping condition based on the value of the gradient of the loss function w.r.t. the weights. For example, let's say I have something like this:

optimizer = tf.train.AdamOptimizer()
grads_and_vars = optimizer.compute_gradients(a_loss_function)
train_op = optimizer.apply_gradients(grads_and_vars)

then I would like to run the graph with something like this:

for step in range(TotSteps):
    output = sess.run([input], feed_dict=some_dict)
    if grad_taken_in_some_way < some_threshold:
        print("Training finished.")
        break

I am not sure what I should pass to sess.run() in order to also get the gradient as an output (besides everything else I need). I am not even sure whether this is the correct approach or whether I should do it differently. I have tried a few things, but they all failed. I hope someone has some hints. Thank you in advance!

EDIT: English correction

EDIT2: lballes's answer is exactly what I wanted to do. Still, I am not sure how to norm and sum all the gradients. Since my CNN has different layers, with weights of different shapes, just doing what was suggested gives me an error on the add_n() operation (I am trying to add together matrices with different shapes). So I should probably do something like:

grad_norms = [tf.nn.l2_normalize(g[0], 0) for g in grads_and_vars]      
grad_norm = [tf.reduce_sum(grads) for grads in grad_norms]
final_grad = tf.reduce_sum(grad_norm)

Can anyone confirm this?
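Or would a sketch like the following work, reducing each gradient to a scalar before summing (assuming grads_and_vars from above), so that add_n never mixes shapes?

squared_sums = [tf.reduce_sum(tf.square(g)) for g, v in grads_and_vars]  # one scalar per variable
total_grad_norm = tf.sqrt(tf.add_n(squared_sums))                        # single scalar over all variables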

Upvotes: 3

Views: 2703

Answers (1)

lballes

Reputation: 1502

Your line output = sess.run([input], feed_dict=some_dict) makes me think that you have a slight misunderstanding of the sess.run command. What you call [input] is supposed to be a list of tensors to be fetched by the sess.run command; hence, they are outputs rather than inputs. To tackle your question, let's assume that you are doing something like output = sess.run(loss, feed_dict=some_dict) instead (in order to monitor the training loss).
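As a minimal, self-contained sketch of that distinction (the toy graph below is illustrative, not taken from your code):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, shape=[None, 3])  # an input, supplied via feed_dict
loss = tf.reduce_mean(tf.square(x))              # an output, retrieved by fetching it

with tf.Session() as sess:
    # first argument = what comes OUT; feed_dict = what goes IN
    loss_value = sess.run(loss, feed_dict={x: np.ones((4, 3))})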

Also, I suppose you want to formulate your stopping criterion using the norm of the gradient (the gradient itself is a multi-dimensional quantity). Hence, you want to fetch the norm of the gradient each time you execute the graph. To that end, you have to do two things: 1) add the gradient norm to the computation graph, and 2) fetch it in each call to sess.run in your training loop.

Ad 1) You have added the gradients to the graph via

optimizer = tf.train.AdamOptimizer()
grads_and_vars = optimizer.compute_gradients(a_loss_function)

and now have the tensors holding the gradients in grads_and_vars (one for each trained variable in the graph). Let's reduce each gradient to a scalar with tf.nn.l2_loss (half its squared L2 norm) and then sum those scalars:

grad_norms = [tf.nn.l2_loss(g) for g, v in grads_and_vars]  # one scalar per variable
grad_norm = tf.add_n(grad_norms)                            # single scalar for the whole gradient

There you have your gradient norm (strictly, half the sum of squared gradient entries, which serves just as well as a stopping signal).
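If you prefer the true L2 norm of the full gradient instead, a sketch using the built-in tf.global_norm, which accepts a list of tensors of arbitrary shapes:

true_grad_norm = tf.global_norm([g for g, v in grads_and_vars])  # sqrt of the summed squares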

Ad 2) Inside your loop, fetch the gradient norm alongside the loss by telling the sess.run command to do so.

for step in range(TotSteps):
    # fetch both tensors in a single graph execution
    l, gn = sess.run([loss, grad_norm], feed_dict=some_dict)
    if gn < some_threshold:
        print("Training finished.")
        break

Upvotes: 3
