I am using TensorFlow 1.14.0 and trying to write a very simple function that includes a conditional statement in TensorFlow. The regular (non-TensorFlow) version is:
def u(x):
    if x < 7:
        y = x + x
    else:
        y = x**2
    return y
However, it seems I cannot use this directly with TensorFlow. If I run it with code like this:
import tensorflow as tf

x = tf.Variable(3, name='x')
sess = tf.Session()
sess.run(x.initializer)
result = sess.run(u(x))
I get the following error:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-26-789531cde07a> in <module>
2 sess=tf.Session()
3 sess.run(x.initializer)
----> 4 result=sess.run(u(x))
5 # print(result)
<ipython-input-23-39f85f34465a> in uu(x)
2
3 def u(x):
----> 4 if x<7:
5 y=x+x
6 else:
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\framework\ops.py in __bool__(self)
688 `TypeError`.
689 """
--> 690 raise TypeError("Using a `tf.Tensor` as a Python `bool` is not allowed. "
691 "Use `if t is not None:` instead of `if t:` to test if a "
692 "tensor is defined, and use TensorFlow ops such as "
TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.
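Why `if x<7:` fails: in graph mode, `x < 7` does not evaluate to `True` or `False`; it builds another tensor whose value is only known once the session runs the graph. Python's `if` calls the object's `__bool__` method, and the traceback shows that `tf.Tensor.__bool__` in `ops.py` simply raises. The plain-Python toy below reproduces that behavior; `SymbolicNode` is a hypothetical stand-in for illustration, not a real TensorFlow class.

```python
# Toy model (hypothetical, NOT TensorFlow's real classes): a graph-mode
# tensor is a symbolic node whose value is unknown until the graph runs.
class SymbolicNode:
    def __init__(self, name):
        self.name = name

    def __lt__(self, other):
        # Comparing a symbolic value only builds another symbolic node.
        return SymbolicNode(f"less({self.name}, {other!r})")

    def __bool__(self):
        # Python's `if` calls __bool__, but there is no concrete value to
        # test yet, so (like tf.Tensor in graph mode) we refuse.
        raise TypeError("Using a symbolic node as a Python bool is not allowed.")


def u(x):
    if x < 7:      # for a SymbolicNode this triggers __bool__ and raises
        y = x + x
    else:
        y = x**2
    return y


try:
    u(SymbolicNode("x"))
except TypeError as e:
    print("TypeError:", e)

print(u(3))  # prints 6: a plain int has a real truth value
```

This is why the error message points to an op such as `tf.cond`: the branch decision has to become part of the graph instead of being made by Python while the graph is being built.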
Following the error message, I rewrote u(x) using tf.cond instead:
import tensorflow as tf

def u(x):
    y = tf.cond(x < 7, lambda: tf.add(x, x), lambda: tf.square(x))
    return y
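To see why `tf.cond` takes two functions (the lambdas) rather than two values, here is a toy deferred-execution graph in plain Python. In TF 1.x graph mode, `tf.cond` calls both branch functions while building the graph, and the session later executes only the branch selected by the predicate; the sketch mimics that shape. `Node`, `CondNode`, `cond` and `run` are hypothetical names for illustration only, not TensorFlow's API or implementation.

```python
# Toy deferred-execution graph in plain Python. Node, CondNode, cond and run
# are hypothetical names for illustration; this is NOT TensorFlow's API or
# implementation, only the same overall shape as graph-mode tf.cond.


class Node:
    """A value in the graph; it is only computed when run() is called."""

    def __init__(self, op, *inputs):
        self.op = op          # function that computes this node from its inputs
        self.inputs = inputs  # upstream nodes

    # Operators build new nodes instead of computing values.
    def __lt__(self, other):
        return Node(lambda a, b: a < b, self, _lift(other))

    def __add__(self, other):
        return Node(lambda a, b: a + b, self, _lift(other))

    def __pow__(self, other):
        return Node(lambda a, b: a ** b, self, _lift(other))


def _lift(value):
    """Wrap a plain Python value as a constant node."""
    return value if isinstance(value, Node) else Node(lambda: value)


class CondNode(Node):
    """Selects one of two already-built branch subgraphs at run time."""

    def __init__(self, pred, true_node, false_node):
        super().__init__(None, pred)
        self.true_node = true_node
        self.false_node = false_node


def cond(pred, true_fn, false_fn):
    # Both branch functions are called here, at graph-construction time,
    # so both subgraphs exist; only the selected one is evaluated later.
    return CondNode(pred, true_fn(), false_fn())


def run(node, feed):
    """Evaluate `node`, reading input values from the `feed` dict."""
    if node in feed:
        return feed[node]
    if isinstance(node, CondNode):
        branch = node.true_node if run(node.inputs[0], feed) else node.false_node
        return run(branch, feed)
    return node.op(*(run(n, feed) for n in node.inputs))


def u(x):
    # Same structure as the tf.cond version of u(x) above.
    return cond(x < 7, lambda: x + x, lambda: x ** 2)


x = Node(None)            # an input, loosely analogous to a placeholder
y = u(x)                  # builds the graph; nothing is computed yet
print(run(y, {x: 3}))     # prints 6  (3 < 7, so x + x)
print(run(y, {x: 9}))     # prints 81 (9 >= 7, so x ** 2)
```

The toy has no notion of variables, so it cannot reproduce the `Retval[0] does not have value` error itself; it only shows why the rewritten u(x) is the right shape for graph-mode code.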
Then I get the following error:
InvalidArgumentError Traceback (most recent call last)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1355 try:
-> 1356 return fn(*args)
1357 except errors.OpError as e:
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1340 return self._call_tf_sessionrun(
-> 1341 options, feed_dict, fetch_list, target_list, run_metadata)
1342
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
1428 self._session, options, feed_dict, fetch_list, target_list,
-> 1429 run_metadata)
1430
InvalidArgumentError: Retval[0] does not have value
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-27-06e1605182c1> in <module>
2 sess=tf.Session()
3 sess.run(x.initializer)
----> 4 result=sess.run(u(x))
5 # print(result)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in run(self, fetches, feed_dict, options, run_metadata)
948 try:
949 result = self._run(None, fetches, feed_dict, options_ptr,
--> 950 run_metadata_ptr)
951 if run_metadata:
952 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1171 if final_fetches or final_targets or (handle and feed_dict_tensor):
1172 results = self._do_run(handle, final_targets, final_fetches,
-> 1173 feed_dict_tensor, options, run_metadata)
1174 else:
1175 results = []
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1348 if handle is None:
1349 return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1350 run_metadata)
1351 else:
1352 return self._do_call(_prun_fn, handle, feeds, fetches)
~\AppData\Roaming\Python\Python37\site-packages\tensorflow\python\client\session.py in _do_call(self, fn, *args)
1368 pass
1369 message = error_interpolation.interpolate(message, self._graph)
-> 1370 raise type(e)(node_def, op, message)
1371
1372 def _extend_graph(self):
InvalidArgumentError: Retval[0] does not have value
I am so confused. Can you please help?
This is a bug in TF (related GitHub issue: Here). For example, the following variations do work:
Changing tf.Variable to tf.constant:
x = tf.constant(3, name='x')

def u(x):
    y = tf.cond(x < 7, lambda: tf.add(x, x), lambda: tf.square(x))
    return y
Changing tf.add to tf.multiply:
def u(x):
    y = tf.cond(x < 7, lambda: tf.multiply(x, x), lambda: tf.square(x))
    return y
Adding a constant instead of x itself:
def u(x):
    y = tf.cond(x < 7, lambda: tf.add(x, 2), lambda: tf.square(x))
    return y
Wrapping x in tf.identity:
def u(x):
    y = tf.cond(x < 7, lambda: tf.math.add(tf.identity(x), tf.identity(x)), lambda: tf.square(x))
    return y
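But tf.add(x, x) (or x + x) fails. The cause seems to be that tf.add has trouble working with tf.Variable inputs but works fine with tf.Tensor inputs, which would also explain why wrapping the variable in tf.identity (which returns a tf.Tensor) helps. I have a hunch that some insight can be found in the source code; I will update this answer if I find anything.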
Update (TF 1.15): You need to enable version 2 of tf.cond, which apparently fixes this issue. Unfortunately, this does not work in 1.14. You can enable it as follows:
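In a Jupyter/IPython notebook:
%env TF_ENABLE_COND_V2='1'
or in plain Python:
import os
os.environ['TF_ENABLE_COND_V2'] = '1'
Set this before TensorFlow is imported, since the flag appears to be read at import time.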
That should give you the desired result.