Reputation: 3659
I am trying to implement a slightly tweaked version of the Batch Normalization operation, in which I need to keep the moving-average values such as the mean and variance explicitly. In order to do that, I am experimenting with the assignment and control dependency mechanisms in TensorFlow, and I have run into a mysterious problem. I have the following toy code, in which I am trying to test whether tf.control_dependencies works as intended:
dataset = MnistDataSet(validation_sample_count=10000,
                       load_validation_from="validation_indices")
samples, labels, indices_list, one_hot_labels = dataset.get_next_batch(batch_size=GlobalConstants.BATCH_SIZE)
samples = np.expand_dims(samples, axis=3)
flat_data = tf.contrib.layers.flatten(GlobalConstants.TRAIN_DATA_TENSOR)
mean = tf.Variable(name="mean", initial_value=tf.constant(100.0, shape=[784], dtype=tf.float32),
                   trainable=False, dtype=tf.float32)
a = tf.Variable(name="a", initial_value=5.0, trainable=False)
b = tf.Variable(name="b", initial_value=4.0, trainable=False)
c = tf.Variable(name="c", initial_value=0.0, trainable=False)
batch_mean, batch_var = tf.nn.moments(flat_data, [0])
b_op = tf.assign(b, a)
mean_op = tf.assign(mean, batch_mean)
with tf.control_dependencies([b_op, mean_op]):
    c = a + b
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
results = sess.run([c, mean], feed_dict={GlobalConstants.TRAIN_DATA_TENSOR: samples})
I am simply loading a data batch in which each entry has 784 dimensions, calculating its moments, and trying to store batch_mean into the variable mean. I also trivially store the variable a's value into b.
In the last line, when I run the graph for the values of c and mean, I see c as 10, which is the expected value. But mean is still a vector of 100s and does not contain the batch mean. It is as if mean_op = tf.assign(mean, batch_mean) had not been executed.
What can be the reason for this? As far as I know, all operations passed to tf.control_dependencies must be executed before any operation defined in the following context; here I explicitly fetch c, which is defined in that context. Am I missing something?
Upvotes: 3
Views: 1429
Reputation: 53768
This is a known "feature" of tf.Session.run(). The c and mean ops are independent, hence mean may be evaluated before c (which would update mean).
Here's a shorter version of this effect:
import tensorflow as tf

a = tf.Variable(name="a", initial_value=1.0, trainable=False)
b = tf.Variable(name="b", initial_value=0.0, trainable=False)
dependent_op = tf.assign(b, a * 3)
with tf.control_dependencies([dependent_op]):
    c = a + 1

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run([c, b]))
    print(sess.run([b]))
The second evaluation of b is guaranteed to return [3.0]. But the first run may return either [2.0, 3.0] or [2.0, 0.0].
Upvotes: 4