Wt.N

Reputation: 1658

TensorFlow: gather pixels from an image by spatial indices

I have a batch of images, [B, H, W, 3], and N spatial indices per image, [B, N, 2], where the last dimension holds the (H, W) coordinates.
Does anyone know how to gather those N pixels from each image, so that the result has shape [B, N, 3]?
It seems tf.gather_nd could help, but I still can't figure out how. Thank you.

https://www.tensorflow.org/api_docs/python/tf/gather_nd
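To pin down the operation being asked for, here is a minimal sketch in plain NumPy (not TensorFlow) of the same gather using advanced indexing: output[b, n] = images[b, h, w] with (h, w) = indices[b, n]. The helper name gather_pixels_np and the toy shapes are hypothetical, chosen only for illustration.

```python
import numpy as np

def gather_pixels_np(images, indices):
    """Hypothetical reference: images [B, H, W, C], indices [B, N, 2] as (h, w) -> [B, N, C]."""
    B = images.shape[0]
    batch = np.arange(B)[:, None]  # [B, 1], broadcasts against the [B, N] index arrays
    # Three advanced indices broadcast to [B, N]; the trailing channel axis C is kept.
    return images[batch, indices[..., 0], indices[..., 1]]

rng = np.random.default_rng(0)
images = rng.random((2, 4, 5, 3))  # B=2, H=4, W=5, C=3
# Draw h in [0, 4) and w in [0, 5), then stack into [B, N, 2] with N=6.
indices = np.stack([rng.integers(0, 4, (2, 6)), rng.integers(0, 5, (2, 6))], axis=-1)
out = gather_pixels_np(images, indices)
print(out.shape)  # (2, 6, 3)
```

The batch index has to be supplied explicitly here; without it, NumPy would pair every image with every index.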

Upvotes: 0

Views: 83

Answers (1)

Andrey

Reputation: 6377

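Use the batch_dims parameter of tf.gather_nd. With batch_dims=1, the first dimension of params and indices is treated as a shared batch dimension, so inds[b] indexes into params[b]: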

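import tensorflow as tf

params = tf.random.uniform((5, 100, 200, 3))             # images: [B, H, W, 3]
inds = tf.random.uniform((5, 300, 2), 0, 100, tf.int32)  # (h, w) indices: [B, N, 2]; values < 100 fit both H=100 and W=200
output = tf.gather_nd(params, inds, batch_dims=1)        # pixels: [B, N, 3] = (5, 300, 3)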
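If batch_dims is not available (for example, on an older TensorFlow version), a sketch of an equivalent approach is to prepend the batch index to every (h, w) pair and call tf.gather_nd without batch_dims. The helper name gather_pixels_explicit is hypothetical; the check below compares it against the batch_dims version.

```python
import tensorflow as tf

def gather_pixels_explicit(images, indices):
    """Hypothetical helper: should match tf.gather_nd(images, indices, batch_dims=1)."""
    indices = tf.cast(indices, tf.int32)  # concat below needs matching dtypes
    B = tf.shape(indices)[0]
    N = tf.shape(indices)[1]
    # Batch index b repeated N times: [B, N, 1]
    batch_idx = tf.tile(tf.reshape(tf.range(B), [B, 1, 1]), [1, N, 1])
    # Full (b, h, w) coordinates: [B, N, 3]
    full_idx = tf.concat([batch_idx, indices], axis=-1)
    return tf.gather_nd(images, full_idx)  # [B, N, C]

params = tf.random.uniform((5, 100, 200, 3))
inds = tf.random.uniform((5, 300, 2), 0, 100, tf.int32)
a = tf.gather_nd(params, inds, batch_dims=1)
b = gather_pixels_explicit(params, inds)
print(a.shape)  # (5, 300, 3)
print(bool(tf.reduce_all(tf.equal(a, b))))  # True
```

Either way, keep h < H and w < W: according to the tf.gather_nd documentation, an out-of-bound index raises an error on CPU, while on GPU the corresponding output value is silently set to 0.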

Upvotes: 2
