Reputation: 768
As stated in the title, is there a TensorFlow equivalent of the numpy.all() function to check whether all the values in a bool tensor are True? What is the best way to implement such a check?
Upvotes: 11
Views: 5000
Reputation: 182
You could use tf.experimental.numpy.all, available from TF 2.4:
import tensorflow as tf

x = tf.constant([False, False])
tf.experimental.numpy.all(x)  # scalar bool tensor: False
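Because tf.experimental.numpy.all mirrors numpy.all, it should also accept an axis argument. The following is a minimal sketch that assumes TF 2.4+ running eagerly. The matrix `m` and the names `whole` and `cols` are illustrative only.

```python
import numpy as np
import tensorflow as tf

m = tf.constant([[True, True], [False, True]])

# Reduce over every element, like np.all(m)
whole = tf.experimental.numpy.all(m)
# Reduce along axis 0 (one result per column), like np.all(m, axis=0)
cols = tf.experimental.numpy.all(m, axis=0)

print(bool(whole))        # False
print(np.asarray(cols))   # [False  True]
```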
Upvotes: 1
Reputation: 19634
Use tf.reduce_all, as follows:
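import tensorflow as tf

a = tf.constant([True, False, True, True], dtype=tf.bool)
res = tf.reduce_all(a)  # scalar bool tensor: True only if every element is True
# TF 2.x (eager execution): read the value directly
print(res.numpy())
# TF 1.x (graph mode), as originally written:
# sess = tf.InteractiveSession()
# res.eval()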
This returns False.
On the other hand, this returns True:
import tensorflow as tf

a = tf.constant([True, True, True, True], dtype=tf.bool)
res = tf.reduce_all(a)
# TF 2.x (eager execution)
print(res.numpy())
# TF 1.x (graph mode): sess = tf.InteractiveSession(); res.eval()
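Like numpy.all, tf.reduce_all can also reduce along specific axes through its `axis` and `keepdims` arguments. The following is a short sketch that assumes TF 2.x in eager mode. The example matrix `m` is illustrative only.

```python
import tensorflow as tf

m = tf.constant([[True, True, True],
                 [True, False, True]])

# No axis: reduce over all elements -> scalar bool tensor
print(tf.reduce_all(m).numpy())                             # False

# axis=1: one result per row
print(tf.reduce_all(m, axis=1).numpy())                     # [ True False]

# keepdims=True keeps the reduced axis with length 1
print(tf.reduce_all(m, axis=1, keepdims=True).shape)        # (2, 1)
```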
Upvotes: 11
Reputation: 768
One way of solving this problem would be to do:
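import tensorflow as tf

def tf_all(bool_tensor):  # renamed from `all` to avoid shadowing Python's built-in
    # Cast True/False to 1.0/0.0; the mean equals 1.0 only if every element is True
    bool_tensor = tf.cast(bool_tensor, tf.float32)
    all_true = tf.equal(tf.reduce_mean(bool_tensor), 1.0)
    return all_true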
However, this is not a dedicated TensorFlow function, just a workaround.
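The mean-based workaround has one caveat. For an empty tensor, the mean is 0/0, which is NaN, so the comparison with 1.0 yields False. By contrast, numpy.all and tf.reduce_all return True for empty input (vacuous truth). The following is a small check that assumes TF 2.x in eager mode. `mean_based_all` is a hypothetical name that simply restates the workaround above.

```python
import numpy as np
import tensorflow as tf

def mean_based_all(bool_tensor):
    # Hypothetical helper restating the workaround: all True <=> mean of 0/1 values is 1.0
    return tf.equal(tf.reduce_mean(tf.cast(bool_tensor, tf.float32)), 1.0)

empty = tf.constant([], dtype=tf.bool)

print(np.all(np.array([], dtype=bool)))   # True
print(tf.reduce_all(empty).numpy())       # True
print(mean_based_all(empty).numpy())      # False (mean of an empty tensor is NaN)
```

On non-empty tensors the two approaches agree. tf.reduce_all still avoids both the cast and the edge case.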
Upvotes: 0