Tsuan

Reputation: 215

What is a local variable in tensorflow?

TensorFlow has this API defined:

tf.local_variables()

Returns all variables created with collection=[LOCAL_VARIABLES].

Returns:

A list of local Variable objects.

What exactly is a local variable in TensorFlow? Can someone give me an example?

Upvotes: 19

Views: 12639

Answers (3)

Munish

Reputation: 988

I think an understanding of TensorFlow collections is required here.

TensorFlow provides collections, which are named lists of tensors or other objects, such as tf.Variable instances.

The following are the built-in collections:

tf.GraphKeys.GLOBAL_VARIABLES               #=> 'variables'
tf.GraphKeys.LOCAL_VARIABLES                #=> 'local_variables'
tf.GraphKeys.MODEL_VARIABLES                #=> 'model_variables'
tf.GraphKeys.TRAINABLE_VARIABLES            #=> 'trainable_variables'

In general, when a variable is created, it can be added to a given collection by explicitly including that collection in the list passed to the collections argument.

Theoretically, a variable can be in any combination of built-in or custom collections. But the built-in collections are used for particular purposes:

  • tf.GraphKeys.GLOBAL_VARIABLES:
    • The Variable() constructor or get_variable() automatically adds new variables to the graph collection GraphKeys.GLOBAL_VARIABLES, unless the collections argument is explicitly passed and doesn't include GLOBAL_VARIABLES.
    • By convention, these variables are shared across distributed environments (model variables are a subset of these).
    • See tf.global_variables() for more details.
  • tf.GraphKeys.TRAINABLE_VARIABLES:
    • When trainable=True is passed (the default behavior), the Variable() constructor and get_variable() automatically add new variables to this graph collection. But of course, you can use the collections argument to add a variable to any desired collection.
    • By convention, these are the variables which will be trained by an optimizer.
    • See tf.trainable_variables() for more details.
  • tf.GraphKeys.LOCAL_VARIABLES:
    • You can use tf.contrib.framework.local_variable() to add to this collection. But of course, you can use the collections argument to add a variable to any desired collection.
    • By convention, these are the variables that are local to each machine. They are per process variables, usually not saved/restored to checkpoint and used for temporary or intermediate values. For example, they can be used as counters for metrics computation or number of epochs this machine has read data.
    • See tf.local_variables() for more details.
  • tf.GraphKeys.MODEL_VARIABLES:
    • You can use tf.contrib.framework.model_variable() to add to this collection. But of course, you can use the collections argument to add a variable to any desired collection.
    • By convention, these are the variables that are used in the model for inference (feed forward).
    • See tf.model_variables() for more details.

You can also use your own collections. Any string is a valid collection name, and there is no need to explicitly create a collection. To add a variable (or any other object) to a collection after creating the variable, call tf.add_to_collection().

For example,

import tensorflow as tf

tf.__version__                                                            #=> '1.9.0'

# initializing using a Tensor
my_variable01 = tf.get_variable("var01", dtype=tf.int32, initializer=tf.constant([23, 42]))
# initializing using a convenient initializer
my_variable02 = tf.get_variable("var02", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.zeros_initializer)

my_variable03 = tf.get_variable("var03", dtype=tf.int32, initializer=tf.constant([1, 2]), trainable=None)
my_variable04 = tf.get_variable("var04", dtype=tf.int32, initializer=tf.constant([3, 4]), trainable=False)
my_variable05 = tf.get_variable("var05", shape=[1, 2, 3], dtype=tf.int32, initializer=tf.ones_initializer, trainable=True)

my_variable06 = tf.get_variable("var06", dtype=tf.int32, initializer=tf.constant([5, 6]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=None)
my_variable07 = tf.get_variable("var07", dtype=tf.int32, initializer=tf.constant([7, 8]), collections=[tf.GraphKeys.LOCAL_VARIABLES], trainable=True)

my_variable08 = tf.get_variable("var08", dtype=tf.int32, initializer=tf.constant(1), collections=[tf.GraphKeys.MODEL_VARIABLES], trainable=None)
my_variable09 = tf.get_variable("var09", dtype=tf.int32, initializer=tf.constant(2),
                                collections=[tf.GraphKeys.GLOBAL_VARIABLES, tf.GraphKeys.LOCAL_VARIABLES,
                                             tf.GraphKeys.MODEL_VARIABLES, tf.GraphKeys.TRAINABLE_VARIABLES,
                                             "my_collection"])
my_variable10 = tf.get_variable("var10", dtype=tf.int32, initializer=tf.constant(3), collections=["my_collection"], trainable=True)

[var.name for var in tf.global_variables()]                               #=> ['var01:0', 'var02:0', 'var03:0', 'var04:0', 'var05:0', 'var09:0']
[var.name for var in tf.local_variables()]                                #=> ['var06:0', 'var07:0', 'var09:0']
[var.name for var in tf.trainable_variables()]                            #=> ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
[var.name for var in tf.model_variables()]                                #=> ['var08:0', 'var09:0']
[var.name for var in tf.get_collection("trainable_variables")]            #=> ['var01:0', 'var02:0', 'var05:0', 'var07:0', 'var09:0', 'var10:0']
[var.name for var in tf.get_collection("my_collection")]                  #=> ['var09:0', 'var10:0']

Upvotes: 11

Salvador Dali

Reputation: 222491

Short answer: a local variable in TF is any variable which was created with collections=[tf.GraphKeys.LOCAL_VARIABLES]. For example:

e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])

LOCAL_VARIABLES: the subset of Variable objects that are local to each machine. Usually used for temporary variables, like counters. Note: use tf.contrib.framework.local_variable to add to this collection.

They are usually not saved/restored to checkpoint and used for temporary or intermediate values.


Long answer: this was a source of confusion for me as well. In the beginning I thought that a local variable meant the same thing as a local variable in almost any other programming language, but it does not:

import tensorflow as tf

def some_func():
    z = tf.Variable(1, name='var_z')

a = tf.Variable(1, name='var_a')
b = tf.get_variable('var_b', 2)  # the second positional argument is `shape`, so this is a float32 variable of shape [2]
with tf.name_scope('aaa'):
    c = tf.Variable(3, name='var_c')

with tf.variable_scope('bbb'):
    d = tf.Variable(3, name='var_d')

some_func()
some_func()

print([str(i.name) for i in tf.global_variables()])
print([str(i.name) for i in tf.local_variables()])

No matter what I tried, I always received only global variables:

['var_a:0', 'var_b:0', 'aaa/var_c:0', 'bbb/var_d:0', 'var_z:0', 'var_z_1:0']
[]

The documentation for tf.local_variables does not provide a lot of details:

Local variables - per process variables, usually not saved/restored to checkpoint and used for temporary or intermediate values. For example, they can be used as counters for metrics computation or number of epochs this machine has read data. The local_variable() automatically adds new variable to GraphKeys.LOCAL_VARIABLES. This convenience function returns the contents of that collection.


But reading the docs for the __init__ method of the tf.Variable class, I found that when creating a variable, you can specify what kind of variable you want it to be by passing a list of collections.

The list of possible collection names is here. So, to create a local variable, you need to do something like this; the variable will then show up in the list of local variables:

e = tf.Variable(6, name='var_e', collections=[tf.GraphKeys.LOCAL_VARIABLES])
print([str(i.name) for i in tf.local_variables()])

Upvotes: 27

Yaroslav Bulatov

Reputation: 57893

It's the same as a regular variable, but it's in a different collection than the default (GraphKeys.VARIABLES, since renamed GraphKeys.GLOBAL_VARIABLES). That collection is used by the Saver to build the default list of variables to save, so giving a variable the local designation means it is not saved by default.

I'm seeing only one place that uses it in the codebase, namely limit_epochs:

  with ops.name_scope(name, "limit_epochs", [tensor]) as name:
    zero64 = constant_op.constant(0, dtype=dtypes.int64)
    epochs = variables.Variable(
        zero64, name="epochs", trainable=False,
        collections=[ops.GraphKeys.LOCAL_VARIABLES])

Upvotes: 19
