Reputation: 3324
I am trying to train a network in TensorFlow with multiple towers, and I set reuse=True for all the towers. But in the CIFAR-10 multi-GPU training example from the TensorFlow tutorials, reuse is only enabled after the first tower has been created:
with tf.variable_scope(tf.get_variable_scope()):
    for i in xrange(FLAGS.num_gpus):
        with tf.device('/gpu:%d' % i):
            with tf.name_scope('%s_%d' % (cifar10.TOWER_NAME, i)) as scope:
                # Dequeues one batch for the GPU
                image_batch, label_batch = batch_queue.dequeue()
                # Calculate the loss for one tower of the CIFAR model. This function
                # constructs the entire CIFAR model but shares the variables across
                # all towers.
                # Actually the logits (whole network) is defined in tower_loss
                loss = tower_loss(scope, image_batch, label_batch)
                # Reuse variables for the next tower.
                tf.get_variable_scope().reuse_variables()
Does it make any difference? What happens if we set reuse=True beforehand?
Upvotes: 0
Views: 912
Reputation: 356
There are two ways to share variables.
Either version 1:
with tf.variable_scope("model"):
    output1 = my_image_filter(input1)
with tf.variable_scope("model", reuse=True):
    output2 = my_image_filter(input2)
or version 2:
with tf.variable_scope("model") as scope:
    output1 = my_image_filter(input1)
    scope.reuse_variables()
    output2 = my_image_filter(input2)
Both methods share the variables. The second one is used in the CIFAR-10 tutorial because it is much cleaner (though that is only my opinion). You can try to rebuild the tutorial with version 1, but the code will probably be less readable.
Upvotes: 1
Reputation: 11968
You need to have reuse=False for the first run so that the variables get created. It is an error if reuse=True but the variable has not been constructed yet.
If you use a newer version of TensorFlow (1.4 or later, I think), you can use reuse=tf.AUTO_REUSE and it will do the magic for you.
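To make the rule above concrete, here is a minimal sketch in plain Python. It is a toy stand-in, not TensorFlow code: ToyVariableStore, build_tower, build_towers and the AUTO_REUSE sentinel are all hypothetical names made up for illustration, and the error messages are not TensorFlow's. It only mimics the create-or-reuse behaviour described in this answer, so treat it as a way to reason about the tower loop rather than as the library's implementation.

```python
# Toy model of TF1-style variable sharing. This is NOT TensorFlow code;
# every name here is hypothetical and exists only for illustration.
AUTO_REUSE = "auto_reuse"  # stand-in for the idea behind tf.AUTO_REUSE


class ToyVariableStore:
    """Tiny registry that mimics the create/reuse rules described above."""

    def __init__(self):
        self.variables = {}
        self.reuse = False  # may be False, True or AUTO_REUSE

    def get_variable(self, name, initial_value=0.0):
        exists = name in self.variables
        if self.reuse is True and not exists:
            # reuse=True can only look up variables that were already created.
            raise ValueError("variable %r does not exist yet" % name)
        if self.reuse is False and exists:
            # reuse=False refuses to silently share an existing variable.
            raise ValueError("variable %r already exists" % name)
        if not exists:
            # Store a list so that object identity shows whether it is shared.
            self.variables[name] = [initial_value]
        return self.variables[name]


def build_tower(store):
    """Hypothetical tower: it just requests the two variables it needs."""
    weights = store.get_variable("conv1/weights")
    biases = store.get_variable("conv1/biases")
    return weights, biases


def build_towers(num_towers, reuse_before_first_tower):
    """Build towers in a loop, switching to reuse after each tower."""
    store = ToyVariableStore()
    store.reuse = reuse_before_first_tower
    towers = []
    for _ in range(num_towers):
        towers.append(build_tower(store))
        if store.reuse is not AUTO_REUSE:
            # Plays the role of reuse_variables() in the tutorial loop.
            store.reuse = True
    return store, towers
```

In this toy, build_towers(2, False) creates the variables in the first tower and shares them in the second, build_towers(2, True) raises a ValueError on the very first tower because nothing exists to reuse, and build_towers(2, AUTO_REUSE) creates or reuses as needed.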
I'm not sure how this interacts with the multi-device setup you have. Double-check that the variable names don't end up prefixed by the device. If they do, nothing is reused and each device gets its own variable.
Upvotes: 2