redb

Reputation: 522

How to select a group of items in a tensor

In TensorFlow, how can I select all triplets (x, y, c) where c > 0.5?


I know this is probably a very basic question, but I'm very new to TensorFlow.

Upvotes: 1

Views: 42

Answers (1)

dgumo

Reputation: 1878

Use tf.where. For example,

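import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.Session)
x = np.random.rand(20, 3)  # 20 triplets (x, y, c)
sess = tf.Session()
idx = tf.where(tf.greater(x[:, 2], 0.5)).eval(session=sess)  # (N, 1) row indices
print(x[idx[:, 0]])  # rows whose c > 0.5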
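If you are on TensorFlow 2.x, where `tf.Session` no longer exists and eager execution is the default, the same `tf.where` idea can be kept entirely inside TensorFlow. This is a minimal sketch under that assumption: `tf.where` with only a condition returns an (N, 1) tensor of matching row indices, and `tf.gather_nd` uses those indices to pull out the full rows.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.constant(np.random.rand(20, 3))  # 20 triplets (x, y, c)
idx = tf.where(x[:, 2] > 0.5)           # shape (N, 1): indices of rows with c > 0.5
selected = tf.gather_nd(x, idx)         # shape (N, 3): the matching triplets
print(selected.numpy())
```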

Or, slightly cleaner, since it returns the matching rows directly: tf.boolean_mask(x, tf.greater(x[:, 2], 0.5)).eval(session=sess)
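For comparison, here is a small sketch of the `tf.boolean_mask` approach, again assuming TensorFlow 2.x eager execution and using a hand-written 3×3 tensor so the result is easy to check. The mask has one boolean per row, so `tf.boolean_mask` keeps whole rows; here only the first and third rows have c > 0.5.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

x = tf.constant([[0.1, 0.2, 0.9],
                 [0.3, 0.4, 0.2],
                 [0.5, 0.6, 0.7]])     # each row is a triplet (x, y, c)
mask = x[:, 2] > 0.5                   # one boolean per row
selected = tf.boolean_mask(x, mask)    # keeps rows where mask is True
print(selected.numpy())
```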

Upvotes: 2
