Reputation: 143
I'm trying to understand the difference between tf.assign and the assignment operator (=). I have three snippets of code.
First, using a simple tf.assign:

import tensorflow as tf

with tf.Graph().as_default():
    a = tf.Variable(1, name="a")
    assign_op = tf.assign(a, tf.add(a, 1))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(assign_op))
        print(a.eval())
        print(a.eval())

The output is as expected:

2
2
2
Second, using the assignment operator:

import tensorflow as tf

with tf.Graph().as_default():
    a = tf.Variable(1, name="a")
    a = a + 1
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print(a.eval())
        print(a.eval())

The results are still 2, 2, 2.
Third, using both tf.assign and the assignment operator:

import tensorflow as tf

with tf.Graph().as_default():
    a = tf.Variable(1, name="a")
    a = tf.assign(a, tf.add(a, 1))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print(a.eval())
        print(a.eval())

Now the output becomes 2, 3, 4.
My questions are:

In the 2nd snippet using (=), when I call sess.run(a), it seems I'm running an assign op. Does "a = a + 1" internally create an assignment op like assign_op = tf.assign(a, a + 1)? Is the op run by the session really just that assign op? But when I run a.eval(), it doesn't continue to increment a, so it appears eval is evaluating a "static" variable.

I'm not sure how to explain the 3rd snippet. Why do the two evals increment a there, when the two evals in the 2nd snippet don't?
Thanks.
Upvotes: 14
Views: 7207
Reputation: 3135
First, the answer is not really precise. IMO, there is no distinction between a Python object and a TF object here; they are all in-memory objects managed by the Python GC.
If you change the second a to b and print the variables out:
In [2]: g = tf.Graph()
In [3]: with g.as_default():
...: a = tf.Variable(1, name='a')
...: b = a + 1
...:
In [4]: print(a)
<tf.Variable 'a:0' shape=() dtype=int32_ref>
In [5]: print(b)
Tensor("add:0", shape=(), dtype=int32)
In [6]: id(a)
Out[6]: 140253111576208
In [7]: id(b)
Out[7]: 140252306449616
a and b are not referring to the same object in memory.
Draw the computation graph (or memory graph).

First line:

# a = tf.Variable(...)
a -> var(a)

Second line:

# b = a + 1
b -> add --- var(a)
        \--- 1
Now, if you change b = a + 1 back to your a = a + 1, the name a after the assignment points to a tf.add tensor instead of to the variable a incremented by 1.

When you run sess.run, you are fetching the result of that add operator, with no side effect on the original a variable.
tf.assign, on the other hand, has the side effect of updating the state of the graph under the session.
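This name-rebinding behaviour is ordinary Python, not something TensorFlow-specific. Here is a minimal pure-Python sketch of the same distinction (the dict and assign_add helper below are just illustrative stand-ins, not TensorFlow API):

```python
# Plain-Python analogy: rebinding a name vs. mutating a stateful object.
state = {"value": 1}     # stands in for the tf.Variable's storage
a = state                # 'a' points at the stateful object

a = a["value"] + 1       # rebinding: 'a' is now the int 2; state is untouched
print(a)                 # 2
print(state["value"])    # 1 -- the "variable" never changed

def assign_add(s, delta):
    """Stands in for tf.assign: mutates the state and returns the new value."""
    s["value"] += delta
    return s["value"]

print(assign_add(state, 1))  # 2
print(assign_add(state, 1))  # 3 -- repeated runs keep incrementing
```

Rebinding only changes what the Python name refers to; only the side-effecting operation actually updates the stored state.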
Upvotes: 1
Reputation: 591
For snippet 1:
with tf.Graph().as_default():
    a = tf.Variable(1, name="a_var")
    assign_op = tf.assign(a, tf.add(a, 1, name='ADD'))
    b = tf.Variable(112)
    b = b.assign(a)
    print(a)
    print(b)
    print(assign_op)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print("assign_op : ", sess.run(assign_op))
        print(" b :- ", b.eval())
        print(sess.run(a))
        print(sess.run(a))
        print("assign_op : ", sess.run(assign_op))
        print(sess.run(a))
        print(sess.run(a))
        writer = tf.summary.FileWriter("/tmp/log", sess.graph)
        writer.close()
The output for snippet 1:
<tf.Variable 'a_var:0' shape=() dtype=int32_ref>
Tensor("Assign_1:0", shape=(), dtype=int32_ref)
Tensor("Assign:0", shape=(), dtype=int32_ref)
1
assign_op : 2
b :- 2
2
2
assign_op : 3
3
3
Have a look at TensorBoard's computation graph to see the points to be noted.

For snippet 2: see the computation graph and you'll get the idea (just note that there is no node for an assignment operation):
with tf.Graph().as_default():
    a = tf.Variable(1, name="Var_a")
    just_a = a + 1
    print(a)
    print(just_a)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print(sess.run(a))
        print("just_a : ", sess.run(just_a))
        print(sess.run(a))
        print(sess.run(a))
        print("just_a : ", sess.run(just_a))
        print(sess.run(a))
        print(sess.run(a))
        writer = tf.summary.FileWriter("/tmp/log", sess.graph)
        writer.close()
The output for snippet 2:
<tf.Variable 'Var_a:0' shape=() dtype=int32_ref>
Tensor("add:0", shape=(), dtype=int32)
1
1
just_a : 2
1
1
just_a : 2
1
1
For snippet 3 (see its computation graph):
with tf.Graph().as_default():
    a = tf.Variable(1, name="Var_name_a")
    a = tf.assign(a, tf.add(a, 5))
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(a))
        print(sess.run(a))
        print(" a : ", sess.run(a))
        print(sess.run(a))
        print(sess.run(a))
        print(" a : ", sess.run(a))
        print(sess.run(a))
        print(sess.run(a))
        writer = tf.summary.FileWriter("/tmp/log", sess.graph)
        writer.close()
The output for snippet 3:
6
11
a : 16
21
26
a : 31
36
41
Now, if you look at the computation graph for this snippet, it looks essentially the same as that of snippet 1. But the catch here is that a = tf.assign(a, tf.add(a, 5)) not only updates the variable a but also rebinds the Python name a to the assign tensor. So every subsequent print(sess.run(a)) runs that assign op. The a inside tf.add(a, 5) is still the original variable (initially 1), so 1 + 5 = 6 is written back into the original variable, and the rebound a returns that freshly updated value each time it is run.
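To see why fetching an assign tensor repeatedly keeps incrementing the variable while fetching a plain add tensor does not, here is a toy model of graph nodes in plain Python (the Var/Add/Assign classes are hypothetical, purely for illustration; this is not how TensorFlow is actually implemented):

```python
# Toy model of graph nodes: each run() call "fetches" the node, like sess.run.
class Var:
    def __init__(self, value):
        self.value = value
    def run(self):
        return self.value  # just reads the current state

class Add:
    def __init__(self, x, delta):
        self.x, self.delta = x, delta
    def run(self):
        return self.x.run() + self.delta  # pure: no state changes

class Assign:
    def __init__(self, var, source):
        self.var, self.source = var, source
    def run(self):
        self.var.value = self.source.run()  # side effect: updates the variable
        return self.var.value

v = Var(1)
just_a = Add(v, 1)                 # like snippet 2: a = a + 1
print(just_a.run(), just_a.run())  # 2 2 -- v never changes

v2 = Var(1)
a = Assign(v2, Add(v2, 1))         # like snippet 3: a = tf.assign(a, tf.add(a, 1))
print(a.run(), a.run(), a.run())   # 2 3 4 -- each fetch writes back into v2
```

Fetching the add node just recomputes a value; fetching the assign node recomputes the value and writes it back, which is why repeated runs snowball.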
I have one more example explaining the concept all at once:
with tf.Graph().as_default():
    w = tf.Variable(10, name="VAR_W")  # initial value = 10
    init_op = tf.global_variables_initializer()
    # Launch the graph in a session.
    with tf.Session() as sess:
        # Run the variable initializer.
        sess.run(init_op)
        print(w.eval())
        print(w)  # type of 'w' before the assign operation
        # CASE 1:
        w = w.assign(w + 50)  # adding 50 to var w
        print(w.eval())
        print(w)  # type of 'w' after the assign operation
        # Now if you try w = w.assign(w + 50) again, you will get an error,
        # because the newly created 'w' is a Tensor with no assign attribute.
        # CASE 2:
        w = tf.assign(w, w + 100)  # adding 100 to var w
        print(w.eval())
        # CASE 3:
        w = tf.assign(w, w + 300)  # adding 300 to var w
        print(w.eval())
        writer = tf.summary.FileWriter("/tmp/log", sess.graph)
        writer.close()
The output for the snippet above:
10
<tf.Variable 'VAR_W:0' shape=() dtype=int32_ref>
60
Tensor("Assign:0", shape=(), dtype=int32_ref)
210
660
Upvotes: 0
Reputation: 30042
The main confusion here is that doing a = a + 1 will reassign the Python variable a to the tensor resulting from the addition operation a + 1. tf.assign, on the other hand, is an operation for setting the value of a TensorFlow variable.

a = tf.Variable(1, name="a")
a = a + 1

This is equivalent to:

a = tf.add(tf.Variable(1, name="a"), 1)
With that in mind:
In the 2nd snippet using (=), when I have sess.run(a), it seems I'm running an assign op. So does "a = a+1" internally create an assignment op like assign_op = tf.assign(a, a+1)? [...]
It might look so, but it is not true. As explained above, this only reassigns the Python variable. Without tf.assign or any other operation that changes the variable, it stays at the value 1. Each time a is evaluated, the program recomputes a + 1, i.e. 1 + 1.
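In other words, the graph stores the recipe a + 1, not its result; each run re-executes the recipe against the unchanged variable. A tiny plain-Python sketch of this (the dict and lambda are assumed stand-ins, not TensorFlow code):

```python
# The graph keeps the *recipe* a + 1, not a computed value.
var = {"value": 1}                  # the tf.Variable, never written to
recipe = lambda: var["value"] + 1   # like the 'add' tensor created by a = a + 1

# Every sess.run / eval re-executes the recipe against the same state:
print(recipe())  # 2
print(recipe())  # 2
print(recipe())  # 2 -- the variable is never updated, so the result never grows
```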
I'm not sure how to explain the 3rd snippet. Why the two evals increment a, but the two evals in the 2nd snippet doesn't?
That's because calling eval() on the assignment tensor in the third snippet also triggers the variable assignment (note that this isn't much different from doing session.run(a) with the current session).
Upvotes: 5