Reputation: 21622
Is `GraphKeys.TRAINABLE_VARIABLES` the same as `tf.trainable_variables()`?
Is `GraphKeys.TRAINABLE_VARIABLES` actually `tf.GraphKeys.TRAINABLE_VARIABLES`?
It looks like the network trains successfully with:
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss, var_list=tf.trainable_variables())
but not with:
optimizer = tf.train.AdamOptimizer(config.LEARNING_RATE)
with tf.control_dependencies(tf.get_collection(tf.GraphKeys.UPDATE_OPS)):
    self.train_op = optimizer.minimize(self.loss)
According to the documentation:
var_list: Optional list or tuple of Variable objects to update to minimize loss. Defaults to the list of variables collected in the graph under the key GraphKeys.TRAINABLE_VARIABLES.
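The documented default can be checked directly. Below is a minimal sketch, assuming TF 1.15 or 2.x and written against the `tf.compat.v1` API (on older TF 1.x, `import tensorflow as tf` behaves the same way); the variables `w`, `b` and `step` are hypothetical. It shows that `tf.GraphKeys.TRAINABLE_VARIABLES` is simply the collection key string, and that looking up that collection returns the same variables as `tf.trainable_variables()`, which in TF 1.x is a thin wrapper around that same lookup.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.15+ or TF 2.x


def default_collection_matches():
    """Compare the TRAINABLE_VARIABLES collection with tf.trainable_variables()."""
    with tf.Graph().as_default():
        # Hypothetical variables, for illustration only.
        w = tf.get_variable("w", shape=[2, 2])
        b = tf.get_variable("b", shape=[2], initializer=tf.zeros_initializer())
        # Non-trainable variables are not added to TRAINABLE_VARIABLES.
        step = tf.get_variable("step", shape=[], trainable=False,
                               initializer=tf.zeros_initializer())

        from_collection = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
        from_helper = tf.trainable_variables()
        return (tf.GraphKeys.TRAINABLE_VARIABLES,
                [v.name for v in from_collection],
                [v.name for v in from_helper])


key, collection_names, helper_names = default_collection_matches()
print(key)               # trainable_variables
print(collection_names)  # ['w:0', 'b:0']
print(helper_names)      # ['w:0', 'b:0']
```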
Also, as far as I can see, `var_list` is omitted in the batch normalization example code:
x_norm = tf.layers.batch_normalization(x, training=training)
# ...
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)
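The control-dependency pattern can be seen in isolation with a minimal sketch (again assuming `tf.compat.v1` on TF 1.15 or 2.x). Instead of calling `tf.layers.batch_normalization`, it registers a hypothetical counter increment under `tf.GraphKeys.UPDATE_OPS`, as a stand-in for the moving-average updates that batch normalization places in that collection, and then checks that every run of `train_op` also runs that update.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.15+ or TF 2.x


def run_steps(num_steps=3):
    """Return how often the registered update op ran after num_steps train steps."""
    with tf.Graph().as_default():
        w = tf.Variable(0.0, name="w")
        loss = tf.square(w - 3.0)

        # Hypothetical stand-in for the moving-mean/variance updates that
        # tf.layers.batch_normalization adds to UPDATE_OPS.
        update_count = tf.Variable(0, trainable=False, name="update_count")
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update_count.assign_add(1))

        optimizer = tf.train.AdamOptimizer(0.1)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # var_list omitted: defaults to the trainable variables (only w here).
            train_op = optimizer.minimize(loss)

        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(num_steps):
                sess.run(train_op)
            return sess.run(update_count)


print(run_steps(3))  # 3
```

Because `minimize()` is called inside `tf.control_dependencies(update_ops)`, the ops it creates wait on the update ops, so the counter advances once per training step even though it is never fetched directly during training.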
Upvotes: 1
Views: 1161
Reputation: 8585
If you don't pass `var_list` to `minimize()`, the variables are retrieved as follows (taken from the `compute_gradients()` source code):
if var_list is None:
    var_list = (
        variables.trainable_variables() +
        ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
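To confirm that this expression is what `compute_gradients()` falls back to, here is a minimal sketch (assuming `tf.compat.v1` on TF 1.15 or 2.x; the placeholder `x` and the variables `w` and `b` are hypothetical). It builds the same expression by hand and compares it, by variable name, with the variables paired with gradients when `var_list` is omitted.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.15+ or TF 2.x


def default_var_list_names():
    """Return (names from the hand-built default, names used by compute_gradients)."""
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, shape=[None, 2])
        # Hypothetical two-parameter linear model.
        w = tf.get_variable("w", shape=[2, 1])
        b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer())
        loss = tf.reduce_mean(tf.square(tf.matmul(x, w) + b))

        # Same expression compute_gradients() uses when var_list is None.
        expected = (tf.trainable_variables()
                    + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))

        optimizer = tf.train.AdamOptimizer(0.01)
        grads_and_vars = optimizer.compute_gradients(loss)  # var_list omitted
        return [v.name for v in expected], [v.name for _, v in grads_and_vars]


expected_names, used_names = default_var_list_names()
print(expected_names == used_names)  # True
```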
If you haven't defined any `ResourceVariable` instances that are somehow missing from `tf.trainable_variables()`, the result should be the same. My guess is that the problem is somewhere else.
You could run a quick check before calling `minimize()` to make sure that there are no `ResourceVariable`s missing from `tf.trainable_variables()`:
import tensorflow as tf

with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, shape=[None, 2])
    with tf.name_scope('network'):
        logits = tf.layers.dense(x, units=2)
    # Same expression minimize() falls back to when var_list is omitted.
    var_list = (tf.trainable_variables()
                + tf.get_collection(tf.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
    # Fails if the default would include variables beyond tf.trainable_variables().
    assert set(var_list) == set(tf.trainable_variables())
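A complementary check, under the same `tf.compat.v1` assumption, is to train an identical toy model twice, once with `var_list=tf.trainable_variables()` and once without, and compare the resulting weights. The model below (the variable `w` and its target values) is hypothetical; it only illustrates that the two calls update the same variables in the same way when no extra resource variables are registered.

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF 1.15+ or TF 2.x


def train(pass_var_list, steps=5):
    """Train a hypothetical toy model, optionally passing var_list explicitly."""
    with tf.Graph().as_default():
        w = tf.Variable([1.0, -2.0], name="w")
        loss = tf.reduce_sum(tf.square(w - tf.constant([3.0, 4.0])))
        optimizer = tf.train.AdamOptimizer(0.1)
        if pass_var_list:
            train_op = optimizer.minimize(loss, var_list=tf.trainable_variables())
        else:
            # Falls back to the trainable-variables default.
            train_op = optimizer.minimize(loss)
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(steps):
                sess.run(train_op)
            return sess.run(w)


explicit = train(pass_var_list=True)
implicit = train(pass_var_list=False)
print(np.allclose(explicit, implicit))  # True
```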
Upvotes: 1