user8490020

Reputation: 143

Difference between tf.assign and assignment operator (=)

I'm trying to understand the difference between tf.assign and the assignment operator (=). I have three code snippets.

First, using plain tf.assign:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  assign_op = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(assign_op))
    print(a.eval())
    print(a.eval())

The output is as expected:

2
2
2

Second, using the assignment operator:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = a + 1
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(a.eval())
    print(a.eval())

The results are still 2, 2, 2.

Third, using both tf.assign and the assignment operator:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a")
  a = tf.assign(a, tf.add(a,1))
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(a.eval())
    print(a.eval())

Now, the output becomes 2, 3, 4.

My questions are:

  1. In the 2nd snippet using (=), when I call sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? Is the op run by the session really just the assign_op? But when I run a.eval(), it doesn't keep incrementing a, so it appears that eval is evaluating a "static" variable.

  2. I'm not sure how to explain the 3rd snippet. Why do the two evals increment a, while the two evals in the 2nd snippet don't?

Thanks.

Upvotes: 14

Views: 7207

Answers (3)

Izana

Reputation: 3135

First, the answer is not really precise. IMO, there is no distinction between Python objects and TF objects: they are all in-memory objects managed by Python's garbage collector.

If you rename the second a to b and print the variables:

In [2]: g = tf.Graph()

In [3]: with g.as_default():
   ...:     a = tf.Variable(1, name='a')
   ...:     b = a + 1
   ...:

In [4]: print(a)
<tf.Variable 'a:0' shape=() dtype=int32_ref>

In [5]: print(b)
Tensor("add:0", shape=(), dtype=int32)

In [6]: id(a)
Out[6]: 140253111576208

In [7]: id(b)
Out[7]: 140252306449616

a and b do not refer to the same object in memory.

Drawing the computation graph (or memory graph):

First line:

# a = tf.Variable(...
a -> var(a)

Second line:

# b = a + 1
b -> add - var(a)
      |
       \-- 1

Now, if you change b = a + 1 back to your original a = a + 1, the name a after the assignment points to the tf.add tensor instead of to the variable a incremented by 1.
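The rebinding described above is ordinary Python behavior and can be checked without TensorFlow. Below is a minimal sketch that uses a hypothetical `Node` class as a stand-in for graph objects (it is not part of TensorFlow): rebinding the name `a` creates a new node that merely points back at the original one, which stays untouched.

```python
class Node:
    """Hypothetical stand-in for a graph node (not a TensorFlow class)."""

    def __init__(self, op, *inputs):
        self.op = op          # name of the operation, e.g. "Variable" or "Add"
        self.inputs = inputs  # upstream nodes or constants


a = Node("Variable")
variable_node = a             # keep a second reference to the original object

a = Node("Add", a, 1)         # rebinds the *name* a to a brand-new Add node

# The name now points at the Add node ...
assert a.op == "Add"
assert a is not variable_node
# ... whose first input is the unchanged Variable node.
assert a.inputs[0] is variable_node
assert variable_node.op == "Variable"
print(a.op, "->", a.inputs[0].op)   # prints: Add -> Variable
```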

When you call sess.run(a), you fetch the result of that add op, which has no side effect on the original variable a.

tf.assign, on the other hand, has the side effect of updating the variable's value held by the session.
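To make the difference between a side-effect-free op and an assign op concrete, here is a toy model in plain Python. It is an assumption-laden sketch, not TensorFlow's implementation: `Variable`, `Add` and `Assign` are hypothetical stand-in classes, a dict plays the role of the session's variable storage, and every fetch re-evaluates the fetched node's inputs. Under those assumptions it reproduces the outputs of the three snippets in the question.

```python
# Toy model (hypothetical, NOT the TensorFlow API): a dict "state" stands in
# for the session's variable storage; only Assign ever writes to it.

class Variable:
    def __init__(self, init):
        self.init = init

    def run(self, state):
        return state[self]                       # read the stored value


class Add:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def run(self, state):
        return self.x.run(state) + self.y        # pure: no side effect


class Assign:
    def __init__(self, ref, value):
        self.ref, self.value = ref, value

    def run(self, state):
        state[self.ref] = self.value.run(state)  # side effect: write the variable
        return state[self.ref]


def snippet1():
    a = Variable(1)
    assign_op = Assign(a, Add(a, 1))
    state = {a: a.init}                          # stands in for the initializer
    return [assign_op.run(state), a.run(state), a.run(state)]


def snippet2():
    a = Variable(1)
    state = {a: a.init}
    a = Add(a, 1)                                # name rebound to an Add node
    return [a.run(state) for _ in range(3)]


def snippet3():
    a = Variable(1)
    state = {a: a.init}
    a = Assign(a, Add(a, 1))                     # name rebound to an Assign node
    return [a.run(state) for _ in range(3)]


print(snippet1(), snippet2(), snippet3())
# prints: [2, 2, 2] [2, 2, 2] [2, 3, 4]
```

In this model the only difference between snippets 2 and 3 is which node the name `a` ends up pointing at; the variable is written only when an `Assign` node is run.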

Upvotes: 1

Gambitier

Reputation: 591

For snippet 1:

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="a_var")
  assign_op = tf.assign(a, tf.add(a,1,name='ADD'))

  b = tf.Variable(112)
  b = b.assign(a)  
  print(a)
  print(b)
  print(assign_op)  

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())    
    print (sess.run(a))      
    print ("assign_op : ",sess.run(assign_op))    
    print("         b :- ",b.eval())
    print (sess.run(a))
    print (sess.run(a))    
    print ("assign_op : ",sess.run(assign_op))
    print (sess.run(a))  
    print (sess.run(a))     
    writer = tf.summary.FileWriter("/tmp/log", sess.graph)
    writer.close()

The output for snippet 1:

<tf.Variable 'a_var:0' shape=() dtype=int32_ref>
Tensor("Assign_1:0", shape=(), dtype=int32_ref)
Tensor("Assign:0", shape=(), dtype=int32_ref)
1
assign_op :  2
        b :-  2
2
2
assign_op :  3
3
3

Have a look at the computational graph in TensorBoard.

Points to note:

  1. First, variable 'a' is evaluated, so the output is 1.
  2. Next, sess.run(assign_op) executes the op built by assign_op = tf.assign(a, tf.add(a,1,name='ADD')), which updates variable 'a' (to 2) and returns the new value. Note that assign_op itself is a Tensor object that was created when the graph was built, not when it was run.

For snippet 2: look at the computational graph and you'll get the idea (note that there is no node for an assignment operation).

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="Var_a")
  just_a = a + 1
  print(a)
  print(just_a)

  # The session must be opened inside the graph context, otherwise it uses
  # the global default graph, which does not contain 'a'.
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(sess.run(a))
    print("just_a : ", sess.run(just_a))
    print(sess.run(a))
    print(sess.run(a))
    print("just_a : ", sess.run(just_a))
    print(sess.run(a))
    print(sess.run(a))
    writer = tf.summary.FileWriter("/tmp/log", sess.graph)
    writer.close()

The output for snippet 2:

<tf.Variable 'Var_a:0' shape=() dtype=int32_ref>
Tensor("add:0", shape=(), dtype=int32)
1
1
just_a :  2
1
1
just_a :  2
1
1

For snippet 3 (see its computational graph):

import tensorflow as tf

with tf.Graph().as_default():
  a = tf.Variable(1, name="Var_name_a")
  a = tf.assign(a, tf.add(a, 5))

  # Session opened inside the graph context so that it sees 'a'.
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(a))
    print(sess.run(a))
    print("        a : ", sess.run(a))
    print(sess.run(a))
    print(sess.run(a))
    print("         a : ", sess.run(a))
    print(sess.run(a))
    print(sess.run(a))
    writer = tf.summary.FileWriter("/tmp/log", sess.graph)
    writer.close()

The output for snippet 3:

6
11
    a :  16
21
26
     a :  31
36
41

If you look at the computation graph for this snippet, it is essentially identical to that of snippet 1. The catch is that the line a = tf.assign(a, tf.add(a,5)) does not only update the variable 'a' (when it is run); it also rebinds the Python name 'a' to the newly created assign tensor.

That newly created 'a' is what

print (sess.run(a))

fetches, i.e. the tensor returned by a = tf.assign(a, tf.add(a,5)).

The 'a' inside tf.add(a,5) is still the original variable from a = tf.Variable(1, name="Var_name_a"), whose value is 1 at first. So the first run computes 1 + 5 = 6, writes it to the original variable, and returns it as the value of the new 'a'. Every later run repeats this, which is why the output grows by 5 each time.

Here is one more example that explains the whole concept at once.

Check the graph here.

import tensorflow as tf

with tf.Graph().as_default():
  w = tf.Variable(10, name="VAR_W")  # initial value = 10

  init_op = tf.global_variables_initializer()

  # Launch the graph in a session.
  with tf.Session() as sess:
    # Run the variable initializer.
    sess.run(init_op)
    print(w.eval())
    print(w)  # type of 'w' before the assign operation

    # CASE 1
    w = w.assign(w + 50)  # add 50 to var w
    print(w.eval())
    print(w)  # type of 'w' after the assign operation

    # If you now try w = w.assign(w + 50) again, you get an error: the name
    # 'w' now refers to the newly created Tensor, which has no assign attribute.

    # CASE 2
    w = tf.assign(w, w + 100)  # re-runs CASE 1 (+50), then adds 100
    print(w.eval())
    # CASE 3
    w = tf.assign(w, w + 300)  # re-runs CASE 2 and CASE 1, then adds 300
    print(w.eval())
    writer = tf.summary.FileWriter("/tmp/log", sess.graph)
    writer.close()

The output of the snippet above:

10
<tf.Variable 'VAR_W:0' shape=() dtype=int32_ref>
60
Tensor("Assign:0", shape=(), dtype=int32_ref)
210
660
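The 10, 60, 210, 660 sequence can be explained by the fact that each new assign takes the previous assign tensor as its input, so fetching it runs every earlier assign first. The sketch below is a hypothetical pure-Python model of that behavior, not the TensorFlow API; it assumes that each node executes at most once per run and that an assign whose target is another assign tensor writes to the same underlying variable.

```python
# Toy model (hypothetical, NOT the TensorFlow API) of chained assignments.

class Variable:
    def __init__(self, init):
        self.init = init
        self.var = self                    # the variable this node refers to

    def compute(self, ev, state):
        return state[self.var]


class Add:
    def __init__(self, x, y):
        self.x, self.y = x, y

    def compute(self, ev, state):
        return ev(self.x) + self.y


class Assign:
    def __init__(self, ref, value):
        self.ref, self.value = ref, value
        self.var = ref.var                 # follow the ref back to the variable

    def compute(self, ev, state):
        ev(self.ref)                       # evaluating the ref runs any upstream Assign
        new_value = ev(self.value)
        state[self.var] = new_value
        return new_value


def run(fetch, state):
    cache = {}                             # each node executes once per run

    def ev(node):
        if node not in cache:
            cache[node] = node.compute(ev, state)
        return cache[node]

    return ev(fetch)


w = Variable(10)
state = {w: w.init}
outputs = [run(w, state)]

w = Assign(w, Add(w, 50))     # CASE 1
outputs.append(run(w, state))
w = Assign(w, Add(w, 100))    # CASE 2: its inputs include the CASE 1 Assign
outputs.append(run(w, state))
w = Assign(w, Add(w, 300))    # CASE 3: also re-runs CASE 2 and CASE 1
outputs.append(run(w, state))

print(outputs)   # prints: [10, 60, 210, 660]
```

Dropping the per-run cache would make the upstream assigns run more than once per fetch and give different numbers, so the "once per run" assumption matters for this model.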

Upvotes: 0

E_net4

Reputation: 30042

The main confusion here is that doing a = a + 1 will reassign the Python variable a to the resulting tensor of the addition operation a + 1. tf.assign, on the other hand, is an operation for setting the value of a TensorFlow variable.

a = tf.Variable(1, name="a")
a = a + 1

This is equivalent to:

a = tf.add(tf.Variable(1, name="a"), 1)

With that in mind:

In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? [...]

It might look that way, but it isn't. As explained above, this only reassigns the Python variable. Without tf.assign or any other operation that changes the variable, it keeps the value 1. Each time a is evaluated, the program will always calculate a + 1 => 1 + 1.

I'm not sure how to explain the 3rd snippet. Why the two evals increment a, but the two evals in the 2nd snippet doesn't?

That's because calling eval() on the assignment tensor in the third snippet also triggers the variable assignment (note that this isn't much different from doing session.run(a) with the current session).

Upvotes: 5
