In TensorFlow, how can I select all triplets (x, y, c) where c > 0.5?
I know this is probably a very basic question, but I'm very new to TensorFlow.
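Use tf.where, which returns the indices of the True entries of a boolean tensor. For example, with x holding 20 random triplets: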
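import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
x = np.random.rand(20, 3)
sess = tf.Session()
# tf.where gives (k, 1) indices of rows with c > 0.5; [:, 0] flattens them to (k,)
print(x[tf.where(tf.greater(x[:, 2], 0.5)).eval(session=sess)[:, 0]])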
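In TensorFlow 2.x, tf.Session is no longer in the top-level namespace (it survives only as tf.compat.v1.Session), so the snippet above won't run as-is. A minimal sketch of the same tf.where approach under eager execution, assuming TensorFlow 2.x; the helper name select_rows_where and the small fixed array are hypothetical, chosen so the result is easy to check by eye:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)


def select_rows_where(triplets, threshold=0.5):
    """Hypothetical helper: return the rows (x, y, c) of an (N, 3) tensor with c > threshold."""
    triplets = tf.convert_to_tensor(triplets)
    # tf.where on a boolean vector returns a (k, 1) int64 tensor of the True positions
    idx = tf.where(tf.greater(triplets[:, 2], threshold))
    # tf.gather_nd with (k, 1) indices picks whole rows, giving shape (k, 3)
    return tf.gather_nd(triplets, idx)


x = np.array([[0.1, 0.2, 0.9],
              [0.3, 0.4, 0.2],
              [0.5, 0.6, 0.7]])
print(select_rows_where(x).numpy())
```

Here the c column is [0.9, 0.2, 0.7], so the first and third rows are kept.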
Or, slightly cleaner, use tf.boolean_mask, which keeps only the rows where the mask is True:
# A 1-D mask over the c column keeps or drops whole rows of x
print(tf.boolean_mask(x, tf.greater(x[:, 2], 0.5)).eval(session=sess))
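The tf.boolean_mask version carries over to TensorFlow 2.x with no session at all. A minimal sketch, assuming TensorFlow 2.x with eager execution and the same small hypothetical array; because the mask is one-dimensional, it selects entire rows of the (N, 3) input:

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution by default)

x = np.array([[0.1, 0.2, 0.9],
              [0.3, 0.4, 0.2],
              [0.5, 0.6, 0.7]])

# Boolean vector of shape (N,): True where c > 0.5
mask = tf.greater(x[:, 2], 0.5)
# A 1-D mask applied to a 2-D input keeps or drops entire rows
selected = tf.boolean_mask(x, mask)
print(selected.numpy())
```

No explicit index tensor appears in user code, which is why this form reads more cleanly than the tf.where one.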