Reputation: 45
I am trying to evaluate whether my variable a is empty (i.e., has size == 0). However, when decorating the code with @tf.function, the if statement incorrectly evaluates as True, whereas when removing the decorator it evaluates as False. tf.size(a) seems to correctly evaluate to 0 in both cases. How do I fix this?
import tensorflow as tf

a = tf.Variable([[]])

@tf.function
def test(a):
    print_op = tf.print(tf.size(a))
    print(tf.size(a))
    if tf.math.not_equal(tf.size(a), 0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

test(a)
Upvotes: 3
Views: 738
Reputation: 14505
It's a little bit of a head-scratcher but, once we understand that tf.function
maps python ops & control flow onto a tf graph, whereas the bare function just executes eagerly, we can pick through it and it makes a lot more sense.
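To make that eager-vs-traced distinction concrete, here is a minimal sketch (the traced/eager function names are my own). The body of a tf.function runs once at trace time with a symbolic tensor, not a concrete value:

```python
import tensorflow as tf

trace_types = []

@tf.function
def traced(x):
    # The python-level body runs while tf.function traces the graph;
    # at that point x is a symbolic tensor, not a concrete value.
    trace_types.append(type(x).__name__)
    return tf.size(x)

def eager(x):
    # Without the decorator everything executes eagerly with concrete values.
    return tf.size(x)

print(int(eager(tf.constant([[]]))))   # concrete result, available immediately
print(int(traced(tf.constant([[]]))))  # same result, once the traced graph executes
print(trace_types)                     # the type seen at trace time is not an EagerTensor
```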
I have tweaked your example to illustrate what's going on. Consider test1
and test2
below:
@tf.function
def test1(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

def test2(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None
These are identical to one another except for the @tf.function
decorator.
Now executing test2(tf.Variable([[]]))
gives us:
0
python print size: 0
which is the behaviour I assume you'd expect. Whereas test1(tf.Variable([[]]))
gives:
python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0
There are a couple of things (beyond the fail
) about this output that you might find surprising:

- the python print()
statement prints out a (yet to be evaluated) tensor rather than a zero
- the outputs of print()
and tf.print()
have been reversed

This is because, by adding the @tf.function
decorator, we no longer have a python function but instead a tf graph mapped from the function code using autograph. This means that, at the point that the if
condition is evaluated, we have not yet executed tf.math.not_equal(tf.size(a),0)
and just have an object (an instance of a Tensor
object) which, in python, is truthy:
class MyClass:
    pass

my_obj = MyClass()

if (my_obj):
    print("my_obj evaluates to true")  ## outputs "my_obj evaluates to true"
This means we get to the print('fail')
statement in test1
before having evaluated tf.math.not_equal(tf.size(a),0).
So what's the fix?
Well, if we remove the call to the python-only print()
function in the if
block and replace it with an autograph-friendly tf.print()
statement then autograph will seamlessly convert our if ... else ...
logic to a graph friendly tf.cond
statement that ensures everything happens in the correct order:
def test3(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        tf.print('fail')
    with tf.control_dependencies([print_op]):
        return None
test3(tf.Variable([[]]))
0
python print size: 0
Upvotes: 2