nitred

Reputation: 5609

How to check if the `compute_gradients` operation has been executed in a TensorFlow graph?

Here's my use-case
I am trying to implement the Model-Agnostic Meta-Learning (MAML) algorithm. At one stage of training I need to compute the gradients of some variables without actually updating those variables. At a later step I would like to do certain things ONLY once the gradient computation is complete.

A simple way to do this would be to use tf.control_dependencies():

import tensorflow as tf

# In this step I would like to COMPUTE gradients
optimizer = tf.train.AdamOptimizer()
# let's assume that I already have loss and var_list
gradients = optimizer.compute_gradients(loss, var_list)

# In this step I would like to do some things ONLY if the gradients are computed
with tf.control_dependencies([gradients]):  # this line raises the TypeError
    pass  # do some stuff

Problem
Unfortunately, the above snippet throws an error: tf.control_dependencies expects each element of its argument to be a tf.Operation or tf.Tensor, but compute_gradients returns a list of (gradient, variable) pairs, so [gradients] is a list that contains another list.

Error message:
TypeError: Can not convert a list into a Tensor or Operation.

What I would like?
I would like one of two things:

Upvotes: 0

Views: 349

Answers (1)

Y. Luo

Reputation: 5732

Since gradients is a list of (gradient, variable) pairs, and those are what you want to be sure have been computed, you can convert it to a flat list of tensors/variables and use that as the control_inputs:

with tf.control_dependencies([t for tup in gradients for t in tup]):
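
To make the flattening step concrete, here is a minimal sketch in plain Python. The helper name flatten_grads_and_vars and its skip_none flag are hypothetical (not part of TensorFlow), and plain strings stand in for the tensors and variables. The None filtering is an extra precaution I am assuming you may want: compute_gradients can return None as the gradient for a variable that does not affect the loss, and a None entry is not a valid control input.

```python
def flatten_grads_and_vars(grads_and_vars, skip_none=True):
    """Flatten [(grad, var), ...] into [grad, var, grad, var, ...].

    Hypothetical helper for illustration only. With skip_none=True,
    None entries (e.g. a missing gradient) are dropped from the result.
    """
    # Same comprehension as in the answer: walk each pair, then each item.
    flat = [t for pair in grads_and_vars for t in pair]
    if skip_none:
        flat = [t for t in flat if t is not None]
    return flat


# Stand-in data: strings play the role of tensors/variables.
# The second pair mimics a variable with no gradient.
example = [("grad_w", "w"), (None, "b")]

print(flatten_grads_and_vars(example))
print(flatten_grads_and_vars(example, skip_none=False))
```

In the graph code you would then pass the result as the control inputs, e.g. tf.control_dependencies(flatten_grads_and_vars(gradients)), and create the dependent ops inside that with block.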

Upvotes: 1
