Albert

Reputation: 399

How to make conditional statements in Tensorflow

I am using TensorFlow 1.14.0 and trying to write a very simple function that includes conditional statements. The regular (non-TensorFlow) version of it is:

def u(x):
    if x<7:
        y=x+x
    else:
        y=x**2
    return y

It seems that I cannot use this directly in TensorFlow. If I call it with code like this:

x=tf.Variable(3,name='x')
sess=tf.Session()
sess.run(x.initializer)
result=sess.run(u(x))

I will get an error like this:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-26-789531cde07a> in <module>
      2 sess=tf.Session()
      3 sess.run(x.initializer)
----> 4 result=sess.run(u(x))
      5 # print(result)

<ipython-input-23-39f85f34465a> in uu(x)
      2 
      3 def u(x):
----> 4     if x<7:
      5         y=x+x
      6     else:

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py in __bool__(self)
    688       `TypeError`.
    689     """
--> 690     raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
    691                     "Use `if t is not None:` instead of `if t:` to test if a "
    692                     "tensor is defined, and use TensorFlow ops such as "

TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
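To illustrate why the plain `if` fails, here is a toy sketch in pure Python. `SymbolicTensor` and `try_python_if` are hypothetical names made up for this illustration and are not TensorFlow's actual classes; the sketch only assumes what the traceback above shows, namely that the comparison yields another tensor-like object whose `__bool__` raises `TypeError`.

```python
class SymbolicTensor:
    """Toy stand-in for a graph-mode tensor (hypothetical, not TensorFlow's class)."""

    def __init__(self, description):
        self.description = description

    def __lt__(self, other):
        # The comparison does not produce True/False; it records a new symbolic node.
        return SymbolicTensor("({} < {})".format(self.description, other))

    def __bool__(self):
        # No concrete value exists until the graph runs, so truthiness is refused.
        raise TypeError("Using a symbolic tensor as a Python bool is not allowed.")


def try_python_if(x):
    """Run a plain Python `if` on x < 7 and report which path was taken."""
    try:
        if x < 7:
            return "took the if-branch"
        return "took the else-branch"
    except TypeError as err:
        return "TypeError: {}".format(err)


print(try_python_if(3))                    # a plain int has a concrete value
print(try_python_if(SymbolicTensor("x")))  # the symbolic object refuses bool()
```

With a plain integer the `if` works as usual; with the symbolic object, Python calls `__bool__` on the result of `x < 7` and the error surfaces. This is why the branch choice has to be expressed as a graph operation such as tf.cond.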

Following the error message, I used tf.cond instead and rewrote the u(x) function:

def u(x):
    import tensorflow as tf
    y=tf.cond(x < 7, lambda: tf.add(x, x), lambda: tf.square(x))
    return y 

Then, I will get the following error:

InvalidArgumentError                      Traceback (most recent call last)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1355     try:
-> 1356       return fn(*args)
   1357     except errors.OpError as e:

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
   1340       return self._call_tf_sessionrun(
-> 1341           options, feed_dict, fetch_list, target_list, run_metadata)
   1342 

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
   1428         self._session, options, feed_dict, fetch_list, target_list,
-> 1429         run_metadata)
   1430 

InvalidArgumentError: Retval[0] does not have value

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-27-06e1605182c1> in <module>
      2 sess=tf.Session()
      3 sess.run(x.initializer)
----> 4 result=sess.run(u(x))
      5 # print(result)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
    948     try:
    949       result = self._run(None, fetches, feed_dict, options_ptr,
--> 950                          run_metadata_ptr)
    951       if run_metadata:
    952         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1171     if final_fetches or final_targets or (handle and feed_dict_tensor):
   1172       results = self._do_run(handle, final_targets, final_fetches,
-> 1173                              feed_dict_tensor, options, run_metadata)
   1174     else:
   1175       results = []

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
   1348     if handle is None:
   1349       return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1350                            run_metadata)
   1351     else:
   1352       return self._do_call(_prun_fn, handle, feeds, fetches)

~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
   1368           pass
   1369       message = error_interpolation.interpolate(message, self._graph)
-> 1370       raise type(e)(node_def, op, message)
   1371 
   1372   def _extend_graph(self):

InvalidArgumentError: Retval[0] does not have value

I am so confused. Can you please help?

Upvotes: 0

Views: 623

Answers (1)

thushv89

Reputation: 11333

This is a bug in TF (there is a related GitHub issue). For example, the following variations work.

What works

Changing tf.Variable to tf.constant

x=tf.constant(3,name='x')
def u(x):    
    y=tf.cond(x < 7, lambda: tf.add(x, x), lambda: tf.square(x))
    return y

Changing tf.add to tf.multiply

def u(x):    
    y=tf.cond(x < 7, lambda: tf.multiply(x, x), lambda: tf.square(x))
    return y

Adding a constant

def u(x):    
    y=tf.cond(x < 7, lambda: tf.add(x, 2), lambda: tf.square(x))
    return y 

Using tf.identity

def u(x):    
    y=tf.cond(x < 7, lambda: tf.math.add(tf.identity(x), tf.identity(x)), lambda: tf.square(x))
    return y 

What doesn't work

But tf.add(x, x) or x + x fails when x is a tf.Variable. The cause appears to be that tf.add has trouble working with tf.Variable types, while it works fine for tf.Tensor types. I have a hunch that some insight can be found in the source code. I will update this answer if I find anything.

Solution (For TF 1.15)

You need to enable version 2 of tf.cond, which apparently has this issue fixed. You can do this by setting the TF_ENABLE_COND_V2 environment variable in either of the following ways. Unfortunately, this does not work for 1.14.

Using magic on Jupyter

%env TF_ENABLE_COND_V2='1'

Using Python

import os
os.environ['TF_ENABLE_COND_V2'] = '1'  # set this before importing tensorflow

And that should give you the desired result.
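One thing that can trip this up is ordering. The sketch below assumes (this is not stated above) that TensorFlow reads the flag when it is first imported, so setting it afterwards may have no effect. The helper name `enable_cond_v2` is hypothetical; it only uses the standard library to set the variable and report whether TensorFlow had already been loaded.

```python
import os
import sys


def enable_cond_v2():
    """Set TF_ENABLE_COND_V2 and return True if TensorFlow was not yet imported."""
    # Assumption: the flag is only picked up at import time.
    set_in_time = "tensorflow" not in sys.modules
    os.environ["TF_ENABLE_COND_V2"] = "1"
    return set_in_time


if not enable_cond_v2():
    print("TensorFlow was already imported; restart the kernel and set the flag first.")
```

Call the helper at the very top of the script or notebook, before any `import tensorflow` line.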

Upvotes: 1
