jackaraz

Reputation: 291

`tf.case` and `tf.cond` execute all of the branch functions in TensorFlow

I'm trying to execute some condition-dependent functions, where each function needs to contract tensors differently depending on their shapes, for instance. However, I realised that `tf.cond` and `tf.case` execute all of the branch functions regardless of the condition. I prepared the following code as an example:

import tensorflow as tf

def a(): 
    print("a")
    return tf.constant(2)
def b(): 
    print("b")
    return tf.constant(3)
def c(): 
    print("c")
    return tf.constant(4)
def d(): 
    print("default")
    return tf.constant(1)

x = tf.constant(1)

@tf.function
def f():
    return tf.case([
        (tf.equal(x,1), a),
        (tf.equal(x,2), b),
        (tf.equal(x,2), c)
    ], default=d, exclusive=True)

@tf.function
def f1():
    def cond3():
        return tf.cond(tf.equal(x,2), c, d)
    def cond2():
        return tf.cond(tf.equal(x,2), b, cond3)
    
    return tf.cond(tf.equal(x,1), a,  cond2)

print(f())
print(f1())

# Output:
# a
# b
# c
# default
# tf.Tensor(2, shape=(), dtype=int32)
# a
# b
# c
# default
# tf.Tensor(2, shape=(), dtype=int32)

As you can see, for both versions the result is as expected, but every function is executed on the way to that result. In my particular case, since I'm doing different calculations depending on the tensor's shape, this leads to a multitude of errors. I've seen many such bug reports but haven't found a solution. Is there another way to do conditional execution, which I'm not aware of, where different functions can be executed depending on the condition? Note that I tried simply using `if tf.equal(x,2): ...`, but then I get an error saying that a tensor output cannot be used as a Python boolean. Also note that this example is a much-simplified version of my problem; my conditions are based on tensor shapes, such as `tf.equal(tf.size(tensor), N)`, so I really need a way to execute different things for different cases.


After @LaplaceRicky's answer I realised that the code I provided was not representative enough, so here is a better example showing what I need to do:

x = tf.ones((3,2,1))
y = tf.ones((1,2,3))
z = tf.ones((4,3,5))
k = tf.ones((3,5,5))

def a(t): 
    def exe():
        return tf.einsum("ijk,lmi", t, y)
    return exe

def b(t): 
    def exe():
        return tf.einsum("ijk,ljm", t, z)
    return exe

def d(t): 
    def exe():
        return tf.einsum("ijk,klm", t, z)
    return exe

c = tf.constant(1)

@tf.function
def f(t):
    y = tf.case([
        (tf.equal(tf.shape(t)[0], 3), a(t)),
        (tf.equal(tf.shape(t)[1], 3), b(t)),
    ], default=d(t), exclusive=True)
    return y



print(f(x))

This function executes properly without the `tf.function` decorator, producing:

tf.Tensor(
[[[[3. 3.]]]
 [[[3. 3.]]]], shape=(2, 1, 1, 2), dtype=float32)

However, when the decorator is included, I get a `ValueError` which suggests that all of the cases are executed.

Upvotes: 2

Views: 1350

Answers (1)

Laplace Ricky

Reputation: 1687

Short answer: use `tf.print` instead of `print` to check whether a particular branch is really being executed in TensorFlow graph mode.

Explanation: `print` does not work in graph mode and won't print there, but it will print during tracing. The printed messages actually imply that all of the branches were added to the TensorFlow graph, but not that all branches will be executed every time in graph mode. `tf.print` should be used instead for debugging.
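A minimal sketch of this tracing-vs-execution difference (assuming TensorFlow 2.x; the function `g` here is just an illustration, not from the question):

```python
import tensorflow as tf

@tf.function
def g(x):
    print("traced")       # Python print: runs only while the graph is being traced
    tf.print("executed")  # tf.print: becomes a graph op, runs on every call
    return x + 1

g(tf.constant(1))  # first call traces: prints "traced" and "executed"
g(tf.constant(2))  # same input signature, no retrace: prints only "executed"
```

So seeing `print` output from every branch only means every branch was traced into the graph, not that every branch runs at call time.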

For more information: https://www.tensorflow.org/guide/function#conditionals

Demonstration:

import tensorflow as tf

def a():
  tf.print('a')
  return tf.constant(10)

def b():
  tf.print('b')
  return tf.constant(11)

def c():
  tf.print('c')
  return tf.constant(12)


@tf.function
def cond_fn(x):
  return tf.switch_case(x, {0:a,1:b}, default=c)

print(cond_fn(tf.constant(0)))
print(cond_fn(tf.constant(1)))
print(cond_fn(tf.constant(2)))

Expected outputs:

a
tf.Tensor(10, shape=(), dtype=int32)
b
tf.Tensor(11, shape=(), dtype=int32)
c
tf.Tensor(12, shape=(), dtype=int32)

The `ValueError` occurs because the TensorFlow graph does not support this kind of shape-dependent branching very well, at least not with `tf.einsum`. One workaround is to build a graph that accepts variable-shaped inputs by using `tf.function(f).get_concrete_function(tf.TensorSpec(shape=[None,None,None]))`.

Besides, `tf.einsum` is problematic in this process and has to be replaced by `tf.transpose` and `tf.tensordot`.

Example Codes:

import numpy as np
import tensorflow as tf

x = tf.random.normal((3,2,1))
y = tf.random.normal((1,2,3))
z = tf.random.normal((4,3,5))
k = tf.random.normal((3,5,5))

# reference version, for checking the values
def f2(t):
    p = tf.case([
        (tf.equal(tf.shape(t)[0], 3), lambda:tf.einsum("ijk,lmi", t, y)),
        (tf.equal(tf.shape(t)[1], 3), lambda:tf.einsum("ijk,ljm", t, z)),
    ], default=lambda:tf.einsum("ijk,klm", t, k), exclusive=True)
    return p

# workaround: plain Python `if`, converted to graph conditionals by AutoGraph
def f(t):
    if tf.shape(t)[0] == 3:
      tf.print('branch a executed')
      return tf.tensordot(tf.transpose(t,[1,2,0]), tf.transpose(y,[2,0,1]),1)
    elif tf.shape(t)[1] == 3:
      tf.print('branch b executed')
      return tf.tensordot(tf.transpose(t,[0,2,1]), tf.transpose(z,[1,0,2]),1)
    else:
      tf.print('branch c executed')
      return tf.tensordot(t, k,1)

graph_f=tf.function(f).get_concrete_function(tf.TensorSpec(shape=[None,None,None]))

print(np.allclose(graph_f(x),f2(x)))
print(np.allclose(graph_f(y),f2(y)))
print(np.allclose(graph_f(z),f2(z)))

Expected outputs:

branch a executed
True
branch c executed
True
branch b executed
True

Upvotes: 1
