I have a batch of images of shape [B, H, W, 3] and, for each image, N spatial indices, giving a tensor of shape [B, N, 2]. The last dimension (size 2) holds the (H, W) coordinates.
Does anyone know how to gather those N pixels from each image, so that the result has shape [B, N, 3]?
It seems tf.gather_nd can help, but I still can't figure out how to use it. Thank you.
https://www.tensorflow.org/api_docs/python/tf/gather_nd
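To pin down the intended result, here is a minimal NumPy sketch of the same lookup using advanced indexing. The helper name gather_pixels_np is hypothetical, and it assumes integer (row, col) indices that are already within bounds. For every batch b and point n it returns images[b, h, w, :].

```python
import numpy as np

def gather_pixels_np(images, inds):
    """Hypothetical helper: images [B, H, W, C], inds [B, N, 2] of (row, col) -> [B, N, C]."""
    batch = np.arange(images.shape[0])[:, None]  # [B, 1], broadcasts against [B, N]
    # The three index arrays broadcast to [B, N]; the untouched channel axis is kept.
    return images[batch, inds[..., 0], inds[..., 1]]

# Usage: 2 images of size 4x5 with 3 channels, 2 points per image.
images = np.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)
inds = np.array([[[0, 0], [3, 4]],
                 [[1, 2], [2, 1]]])
out = gather_pixels_np(images, inds)  # shape (2, 2, 3)
```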
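Use the batch_dims parameter. With batch_dims=1, the first dimension of params and of the indices is treated as a batch dimension, so each (h, w) pair is looked up only in its own image:
```python
import tensorflow as tf

params = tf.random.uniform((5, 100, 200, 3))              # [B, H, W, 3]
inds = tf.random.uniform((5, 300, 2), 0, 100, tf.int32)   # [B, N, 2] (h, w) pairs; maxval 100 keeps both in range
output = tf.gather_nd(params, inds, batch_dims=1)          # [B, N, 3] -> [5, 300, 3]
```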
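If batch_dims is unavailable in your TensorFlow version, a common workaround is to prepend the batch index to every (h, w) pair and call tf.gather_nd on the full [B, N, 3] index tensor. The sketch below assumes TF 2.x eager execution and int32 indices; gather_pixels is a hypothetical helper name, and the last line checks it against the batch_dims version.

```python
import tensorflow as tf

def gather_pixels(images, inds):
    """Hypothetical helper: images [B, H, W, C], inds [B, N, 2] of (row, col) -> [B, N, C]."""
    batch = tf.range(tf.shape(inds)[0], dtype=inds.dtype)[:, None, None]  # [B, 1, 1]
    batch_idx = tf.zeros_like(inds[..., :1]) + batch    # broadcast to [B, N, 1]
    full_inds = tf.concat([batch_idx, inds], axis=-1)   # [B, N, 3]: (batch, row, col)
    return tf.gather_nd(images, full_inds)              # no batch_dims needed

params = tf.random.uniform((5, 100, 200, 3))
inds = tf.random.uniform((5, 300, 2), 0, 100, tf.int32)
manual = gather_pixels(params, inds)
with_batch_dims = tf.gather_nd(params, inds, batch_dims=1)
same = bool(tf.reduce_all(tf.equal(manual, with_batch_dims)))
```

Both approaches perform the same lookup. batch_dims=1 simply saves you from building the batch column by hand.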