mrgloom

Reputation: 21622

GraphKeys.TRAINABLE_VARIABLES vs tf.trainable_variables()

Is GraphKeys.TRAINABLE_VARIABLES the same as tf.trainable_variables()?

Is GraphKeys.TRAINABLE_VARIABLES actually tf.GraphKeys.TRAINABLE_VARIABLES?

It looks like the network trains successfully with:

optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss, var_list=tf.trainable_variables())

but not with:

optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss)

According to the documentation:

var_list: Optional list or tuple of Variable objects to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.

Also, as far as I can see, var_list is omitted in the batch normalization example code:

  x_norm = tf.layers.batch_normalization(x, training=training)

  # ...

  update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
  with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)

Upvotes: 1

Views: 1161

Answers (1)

Vlad

Reputation: 8585

If you don't pass var_list to minimize(), the variables are retrieved in the following way (taken from the compute_gradients() source code):

if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

If you haven't defined any ResourceVariable instances that are somehow missing from tf.trainable_variables(), the result should be the same. My guess is that the problem is somewhere else.

Before calling minimize(), you could run a check like the one below to make sure there are no ResourceVariables missing from tf.trainable_variables():

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2])
    with tf.name_scope('network'):
        logits = tf.layers.dense(x, units=2)

    var_list = (tf.trainable_variables()
                + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    assert set(var_list) == set(tf.trainable_variables())
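
As for the first part of the question, here is a minimal sketch showing that tf.trainable_variables() returns the contents of the tf.GraphKeys.TRAINABLE_VARIABLES collection. It assumes the TF 1.x graph API, imported through tensorflow.compat.v1 so that it should also run on a TF 2.x install. The helper trainable_names and the variables 'w' and 'frozen' are made up for illustration and are not part of your code:

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style graph API


def trainable_names(graph):
    """Return the trainable variable names obtained through both lookups."""
    with graph.as_default():
        from_function = [v.name for v in tf.trainable_variables()]
        from_collection = [
            v.name
            for v in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)]
    return from_function, from_collection


graph = tf.Graph()
with graph.as_default():
    # Hypothetical variables, only here to populate the graph collections.
    tf.get_variable('w', shape=[2, 2])
    tf.get_variable('frozen', shape=[2], trainable=False)

from_function, from_collection = trainable_names(graph)
print(from_function)
print(from_collection)
```

Both lists should contain the same names, and the variable created with trainable=False should appear in neither. Names are compared instead of the variable objects themselves to avoid depending on how variables implement equality across TensorFlow versions.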

Upvotes: 1
