Reputation: 2240
I have a CNN output tensor X of shape (N, 256, 256, 5), where N is the batch dimension. I have tensors x and y, each containing N indices (each from 0 to 255). I'd like to use these indices to form an (N, 5) tensor Y such that Y[n, :] = X[n, x[n], y[n], :]. How can this be done?
Upvotes: 1
Views: 257
Reputation: 2240
This works using tf.gather with batch_dims=1 (rather than tf.gather_nd, as in the original question):
tf.gather(tf.gather(X, x, batch_dims=1), y, batch_dims=1)
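One way to convince yourself that the nested tf.gather call does what the question asks is to compare it with NumPy advanced indexing on a small example. This is only a sketch: the sizes (N = 4, H = W = 6 instead of 256), the random data, and the variable names rows, Y_ref and the *_np arrays are assumptions made for illustration.

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes (assumptions); the question uses H = W = 256.
N, H, W, C = 4, 6, 6, 5
rng = np.random.default_rng(0)
X_np = rng.normal(size=(N, H, W, C)).astype(np.float32)
x_np = rng.integers(0, H, size=N).astype(np.int32)
y_np = rng.integers(0, W, size=N).astype(np.int32)

X, x, y = tf.constant(X_np), tf.constant(x_np), tf.constant(y_np)

# First gather picks row x[n] for each batch element: (N, H, W, C) -> (N, W, C).
rows = tf.gather(X, x, batch_dims=1)
# Second gather picks column y[n] for each batch element: (N, W, C) -> (N, C).
Y = tf.gather(rows, y, batch_dims=1)

# Reference result: Y_ref[n, :] = X[n, x[n], y[n], :].
Y_ref = X_np[np.arange(N), x_np, y_np, :]

print(Y.shape)
print(np.allclose(Y.numpy(), Y_ref))
```

With batch_dims=1, each gather treats the leading axis as a batch axis and indexes the axis right after it, which is why two calls are enough to consume both spatial axes.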
Upvotes: 0
Reputation: 26708
I think something similar to this could do the trick for you (if I understood your question correctly):
Your data:
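import tensorflow as tf
import numpy as np

batch_size = 5
D = 2
# Values 0..99 arranged as (batch_size, D, D, 5); int32 matches the printed output below
data = tf.constant(np.arange(batch_size * D * D * 5, dtype=np.int32).reshape([batch_size, D, D, 5]))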
Calculate indices:
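# Batch index 0..batch_size-1 as a column, shape (batch_size, 1)
batches = tf.reshape(tf.range(batch_size, dtype=tf.int32), shape=[batch_size, 1])
# Random spatial indices in [0, D), one pair per batch element
random_x = tf.random.uniform([batch_size, 1], minval=0, maxval=D, dtype=tf.int32)
random_y = tf.random.uniform([batch_size, 1], minval=0, maxval=D, dtype=tf.int32)
# Each row is (n, x[n], y[n]); resulting shape is (batch_size, 3)
indices = tf.concat([batches, random_x, random_y], axis=1)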
Note that random_x and random_y can be replaced by your existing x and y tensors (reshaped to [batch_size, 1] if they are 1-D, so that they can be concatenated with batches). Then use the tf.gather_nd function to apply your indices to your tensor data:
output = tf.gather_nd(data, indices)
print(batches, 'batches')
print(random_x, 'random_x')
print(random_y, 'random_y')
print(indices, 'indices')
print('Original tensor \n', data, '\n')
print('Updated tensor \n', output)
'''
tf.Tensor(
[[0]
 [1]
 [2]
 [3]
 [4]], shape=(5, 1), dtype=int32) batches
tf.Tensor(
[[0]
 [1]
 [1]
 [0]
 [1]], shape=(5, 1), dtype=int32) random_x
tf.Tensor(
[[0]
 [1]
 [0]
 [0]
 [0]], shape=(5, 1), dtype=int32) random_y
tf.Tensor(
[[0 0 0]
 [1 1 1]
 [2 1 0]
 [3 0 0]
 [4 1 0]], shape=(5, 3), dtype=int32) indices
Original tensor
tf.Tensor(
[[[[ 0  1  2  3  4]
   [ 5  6  7  8  9]]

  [[10 11 12 13 14]
   [15 16 17 18 19]]]


 [[[20 21 22 23 24]
   [25 26 27 28 29]]

  [[30 31 32 33 34]
   [35 36 37 38 39]]]


 [[[40 41 42 43 44]
   [45 46 47 48 49]]

  [[50 51 52 53 54]
   [55 56 57 58 59]]]


 [[[60 61 62 63 64]
   [65 66 67 68 69]]

  [[70 71 72 73 74]
   [75 76 77 78 79]]]


 [[[80 81 82 83 84]
   [85 86 87 88 89]]

  [[90 91 92 93 94]
   [95 96 97 98 99]]]], shape=(5, 2, 2, 5), dtype=int32)

Updated tensor
tf.Tensor(
[[ 0  1  2  3  4]
 [35 36 37 38 39]
 [50 51 52 53 54]
 [60 61 62 63 64]
 [90 91 92 93 94]], shape=(5, 5), dtype=int32)
'''
The tensor output has a shape of (batch_size, 5). As I said, I am not sure if I understood the question, so feel free to give some feedback.
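For the question's setup, where x and y are 1-D tensors of length N, the same tf.gather_nd idea might look like the sketch below. The sizes, the concrete index values, and the names Y_a and Y_b are made up for illustration. The second variant uses the batch_dims argument of tf.gather_nd, which avoids building the batch index by hand.

```python
import numpy as np
import tensorflow as tf

# Illustrative sizes and indices (assumptions); the question uses H = W = 256.
N, H, W, C = 4, 6, 6, 5
X = tf.reshape(tf.range(N * H * W * C), [N, H, W, C])
x = tf.constant([0, 5, 2, 3], dtype=tf.int32)
y = tf.constant([1, 0, 4, 5], dtype=tf.int32)

# Variant A: explicit batch index, each row of indices is (n, x[n], y[n]).
indices = tf.stack([tf.range(N, dtype=tf.int32), x, y], axis=1)  # shape (N, 3)
Y_a = tf.gather_nd(X, indices)                                   # shape (N, C)

# Variant B: let gather_nd handle the batch axis via batch_dims=1,
# so each row of the index tensor is just (x[n], y[n]).
Y_b = tf.gather_nd(X, tf.stack([x, y], axis=1), batch_dims=1)    # shape (N, C)

# Reference result with NumPy advanced indexing.
Y_ref = X.numpy()[np.arange(N), x.numpy(), y.numpy(), :]
print(np.array_equal(Y_a.numpy(), Y_ref), np.array_equal(Y_b.numpy(), Y_ref))
```

Both variants should give the same (N, 5)-shaped result; which one reads better is largely a matter of taste.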
Upvotes: 1