Jo37

Reputation: 45

Decorating with @tf.function changes if condition output

I am trying to evaluate whether my variable a is empty (i.e., has size == 0). However, when I decorate the code with @tf.function, the if statement incorrectly evaluates as True, whereas without the decorator it evaluates as False. tf.size(a) seems to correctly evaluate to 0 in both cases. How can I fix this?

import tensorflow as tf
a=tf.Variable([[]])
@tf.function
def test(a):
    print_op = tf.print(tf.size(a))
    print(tf.size(a))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None
test(a)

Upvotes: 3

Views: 738

Answers (1)

Stewart_R

Reputation: 14505

It's a bit of a head-scratcher, but once we understand that tf.function traces the python ops and control flow into a tf graph, whereas the undecorated function simply executes eagerly, we can pick through it and it makes a lot more sense.
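As a quick sanity check of the eager-versus-graph distinction, the sketch below records tf.executing_eagerly() from inside the same python body, once when it is called directly and once when it is wrapped with tf.function. The names record_mode and modes are hypothetical, chosen for illustration, and the sketch assumes TF 2.x defaults (tf.config.run_functions_eagerly not enabled).

```python
import tensorflow as tf

# Hypothetical helper: records whether the python body runs eagerly.
# Assumes TF 2.x defaults (run_functions_eagerly is off).
modes = []

def record_mode(x):
    modes.append(tf.executing_eagerly())  # python side effect
    return x + 1

record_mode(tf.constant(1))               # plain call: runs eagerly
tf.function(record_mode)(tf.constant(1))  # wrapped call: body runs while tracing a graph
print(modes)  # [True, False]
```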

I have tweaked your example to illustrate what's going on. Consider test1 and test2 below:

@tf.function
def test1(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

def test2(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        print('fail')
    with tf.control_dependencies([print_op]):
        return None

These are identical to one another except for the @tf.function decorator.

Now executing test2(tf.Variable([[]])) gives us:

0
python print size: 0

This is the behaviour I assume you'd expect. By contrast, test1(tf.Variable([[]])) gives:

python print size: Tensor("Size_1:0", shape=(), dtype=int32)
fail
0

There are a couple of things (beyond the fail) about this output that you might find surprising:

  • The print() statement prints out a (yet to be evaluated) tensor rather than a zero
  • The order of the print() and the tf.print() have been reversed
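Both surprises can be reproduced with a smaller sketch. It uses a hypothetical show function and python_calls list and assumes TF 2.x defaults. Python side effects run once, while the function is traced; tf.print runs each time the graph executes. Note that tf.print writes to stderr by default, so some consoles may interleave its output differently with print().

```python
import tensorflow as tf

# Hypothetical example: python_calls is appended to only while tracing.
python_calls = []

@tf.function
def show(x):
    python_calls.append(x)     # python side effect: trace time only
    print("python print:", x)  # prints a symbolic Tensor, once, during tracing
    tf.print("tf.print:", x)   # graph op: prints the value on every call (stderr)
    return x

show(tf.constant(1))
show(tf.constant(2))  # same dtype and shape, so the first trace is reused
print(len(python_calls))  # 1
```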

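This is because, by adding @tf.function, we no longer have a python function that runs eagerly. Instead, the function body is traced into a tf graph, and autograph rewrites python control flow into graph ops along the way. During tracing, tf.size(a) and tf.math.not_equal(tf.size(a),0) have not been executed yet. They are symbolic Tensor objects, which is why print() shows a Tensor rather than 0. Because the if condition is such a tensor, autograph converts the if statement into a tf.cond, and tf.cond traces both branches in order to build the graph. Tracing a branch runs its python code, so print('fail') fires during tracing regardless of what the condition later evaluates to.

This is not plain python truthiness. An ordinary python object is truthy:

class MyClass:
  pass
my_obj = MyClass()
if (my_obj):
  print ("my_obj evaluates to true") ## outputs "my_obj evaluates to true"

A symbolic graph tensor, however, refuses to be used as a python bool (with autograph disabled, the same if statement raises an error). That is exactly why autograph has to turn the if into a tf.cond.

The same trace-time/run-time split explains the reversed order. The python print() calls run while the function is being traced, whereas tf.print() is a graph op that only prints when the traced graph is executed afterwards.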
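When the condition is a tensor, tf.cond builds both branch functions into the graph. The minimal sketch below calls tf.cond explicitly to show this. The branch functions, pick and traced_branches are hypothetical names, and TF 2.x is assumed. The predicate is false when the graph runs, yet both python branch bodies execute while the graph is being built.

```python
import tensorflow as tf

# Hypothetical record of which branch functions ran during graph construction.
traced_branches = []

def then_branch():
    traced_branches.append("then")  # runs while tf.cond builds this branch
    return tf.constant(1)

def else_branch():
    traced_branches.append("else")  # runs while tf.cond builds this branch
    return tf.constant(0)

@tf.function
def pick(x):
    # x is symbolic here, so tf.cond must build both branches into the graph
    return tf.cond(x > 0, then_branch, else_branch)

result = pick(tf.constant(-5))
print(int(result), sorted(traced_branches))  # 0 ['else', 'then']
```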

So what's the fix?

Well, if we replace the python-only print() call in the if block with an autograph-friendly tf.print(), the message becomes an op inside the tf.cond branch that autograph builds. Tracing still visits that branch, but the op only executes at run time if the condition is actually true, so everything happens in the correct order:

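@tf.function
def test3(a):
    print_op = tf.print(tf.size(a))
    print("python print size: {}".format(tf.size(a)))
    if tf.math.not_equal(tf.size(a),0):
        tf.print('fail')  # graph op: only runs if this branch is taken at run time
    with tf.control_dependencies([print_op]):
        return None

test3(tf.Variable([[]]))

This prints the symbolic tensor from the python print() at trace time, then 0 from tf.print() when the graph runs, and no fail:

python print size: Tensor("Size_1:0", shape=(), dtype=int32)
0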
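Sometimes the branch needs to produce a value rather than just print. Autograph applies the same tf.cond conversion when both branches assign the same variable. The sketch below uses a hypothetical describe function and assumes TF 2.x. It returns a string tensor chosen when the graph runs, so it reports empty for the question's tf.Variable([[]]).

```python
import tensorflow as tf

@tf.function
def describe(a):  # hypothetical helper, not from the original post
    # Autograph turns this tensor-valued if/else into a tf.cond; each branch
    # only creates a tensor, so tracing both has no misleading side effect.
    if tf.equal(tf.size(a), 0):
        result = tf.constant("empty")
    else:
        result = tf.constant("non-empty")
    return result

print(describe(tf.Variable([[]])).numpy())     # b'empty'
print(describe(tf.Variable([[1.0]])).numpy())  # b'non-empty'
```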

Upvotes: 2
