stark

Reputation: 97

TensorFlow: Initializing variables multiple times

I am a bit confused by how the following code segment runs.

import tensorflow as tf

x = tf.Variable(0)
init_op = tf.initialize_all_variables()
modify_op = x.assign(5)

with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(x))
    x += 3
    print(sess.run(x))
    sess.run(init_op) # Trying to initialize x once again to 0
    print(sess.run(x)) # Gives out 3, which leaves me confused.
    print(sess.run(modify_op))
    print(sess.run(x)) # Gives out 8, even more confusing

This is the output:
0
3
3
5
8

Is it that the line x += 3 is not part of the default graph, or is something else going on? Any help would be appreciated, thanks!

Upvotes: 2

Views: 2159

Answers (1)

Neil Slater

Reputation: 27227

Your x variable is being changed by

x += 3

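but not in the way you might expect. The TensorFlow library overrides the + operator, so x += 3 effectively rebinds the Python name x to a new TF tensor: the original tf.Variable is still in the graph, but x now points to the result of the addition. Every later sess.run(x) therefore evaluates "variable + 3", which is why you see 3 after re-initializing (0 + 3) and 8 after assigning 5 (5 + 3). Write it out like this: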

x = tf.Variable(0) + 3

and it is clearer what is going on. Printing x before and after the addition shows the change of type:

x = tf.Variable(0)
print(x)
# <tensorflow.python.ops.variables.Variable object at 0x1018f5d68>

x += 3
print(x)
# Tensor("add:0", shape=(), dtype=int32)
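The same rebinding can be reproduced without TensorFlow. The following is a minimal pure-Python sketch, not TensorFlow's actual implementation: ToyVariable, ToyAdd and run_demo are hypothetical names invented for illustration, and the toy only assumes that + builds a new node instead of changing the variable. Under that assumption it replays the sequence from the question.

```python
class ToyVariable:
    """Hypothetical stand-in for a graph variable: it keeps state between runs."""

    def __init__(self, initial_value):
        self.initial_value = initial_value
        self.value = None  # unset until initialize() is called

    def initialize(self):
        # Plays the role of running the init op: reset to the initial value.
        self.value = self.initial_value

    def assign(self, new_value):
        # Plays the role of an assign op: change the stored state.
        self.value = new_value
        return self.value

    def evaluate(self):
        return self.value

    def __add__(self, other):
        # Overloaded "+": build a NEW node and leave the variable untouched.
        return ToyAdd(self, other)


class ToyAdd:
    """Hypothetical 'add' node: recomputed from its input on every evaluation."""

    def __init__(self, node, constant):
        self.node = node
        self.constant = constant

    def evaluate(self):
        return self.node.evaluate() + self.constant


def run_demo():
    outputs = []
    var = ToyVariable(0)
    x = var                         # x and var name the same object for now
    var.initialize()
    outputs.append(x.evaluate())    # the variable itself
    x += 3                          # no __iadd__, so Python runs x = x.__add__(3)
    outputs.append(x.evaluate())    # x is now the add node
    var.initialize()                # resets the variable, not the add node
    outputs.append(x.evaluate())
    outputs.append(var.assign(5))
    outputs.append(x.evaluate())
    return outputs


print(run_demo())  # [0, 3, 3, 5, 8]
```

The key line is x += 3: because the toy class defines __add__ but not __iadd__, Python falls back to x = x + 3, so the name x ends up bound to the add node while the variable keeps its own state.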

If you want to track or display the variable through the Python name x later, avoid re-assigning x. Alternatively, you can name the tensor and fetch it directly from the graph when you don't have a convenient Python variable pointing at it. What matters is getting used to the separation between TF variables and Python variables.
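As a rough sketch of the "name the tensor and fetch it from the graph" option: the code below assumes TensorFlow 1.14+ or 2.x, where the graph-mode API lives under tensorflow.compat.v1, and uses tf.global_variables_initializer(), the newer name for tf.initialize_all_variables(). The tensor name "x_plus_3" and the function fetch_by_name are made up for this example.

```python
import tensorflow.compat.v1 as tf

# Assumption: a TF 1.14+ / 2.x install, so graph mode has to be requested.
tf.disable_eager_execution()


def fetch_by_name():
    graph = tf.Graph()
    with graph.as_default():
        x = tf.Variable(0, name="x")
        # Name the derived tensor explicitly; x keeps pointing at the variable
        # and the result of the addition is not bound to any Python name.
        tf.add(x, 3, name="x_plus_3")
        init_op = tf.global_variables_initializer()

    with tf.Session(graph=graph) as sess:
        sess.run(init_op)
        # Look the tensor up by name: "<op name>:<output index>".
        x_plus_3 = graph.get_tensor_by_name("x_plus_3:0")
        return sess.run(x), sess.run(x_plus_3)


variable_value, sum_value = fetch_by_name()
print(variable_value, sum_value)
```

Here the Python name x still refers to the variable, and the addition is reachable through its graph name only, which keeps the two roles apart.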

To actually see the TF variable being modified and then reset, as you are trying to do, you need a TF assignment op:

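import tensorflow as tf

x = tf.Variable(0)

with tf.Session() as session:
    session.run(tf.initialize_all_variables())
    print(x.eval())

    # assign() writes x + 3 back into the variable itself;
    # the Python name x still points at the tf.Variable
    session.run(x.assign(x + 3))
    print(x.eval())

    # re-running the initializer resets the variable to its initial value
    session.run(tf.initialize_all_variables())
    print(x.eval())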

This outputs:

0
3
0

as you were expecting.

Upvotes: 3
