I have two tensors of the following shapes:
tensor1 => shape (10, 99, 106)
tensor2 => shape (10, 99)
tensor2 contains integer values in the range 0 to 105, which I want to use to index the last dimension of tensor1 and obtain tensor3 of shape:
tensor3 => shape (10, 99, 99)
I have tried using:
tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)
I also tried:
tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: last dimension of tensor2 (which is 99) must be
# less than the rank of tensor1 (which is 3).
What I am looking for is something resembling NumPy's cross-indexing for this.
You can use tf.map_fn:
tensor3 = tf.map_fn(lambda u: tf.gather(u[0], u[1], axis=1), [tensor1, tensor2], dtype=tensor1.dtype)
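You can think of this line as a loop over the first dimension of tensor1 and tensor2: for each index i in that dimension, it applies tf.gather to tensor1[i, :, :] and tensor2[i, :], gathering along axis 1 (the last axis of tensor1[i]). Each step yields a (99, 99) slice, so tensor3[i, j, k] = tensor1[i, j, tensor2[i, k]] and the stacked result has shape (10, 99, 99).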
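To check the semantics without TensorFlow, here is a minimal NumPy sketch (an illustration, not the answer's TensorFlow code) with small hypothetical shapes standing in for (10, 99, 106) and (10, 99). loop_version mirrors the per-index loop that tf.map_fn performs; fancy_index_version is the NumPy advanced-indexing form the question alludes to. Both function names are made up for this example.

```python
import numpy as np

# Hypothetical small shapes standing in for (10, 99, 106) and (10, 99).
B, R, D = 2, 3, 5   # batch size, rows, size of the last dimension
K = R               # tensor2 holds one index per row position, as in the question

rng = np.random.default_rng(0)
tensor1 = rng.standard_normal((B, R, D))
tensor2 = rng.integers(0, D, size=(B, K))  # indices in [0, D)

def loop_version(t1, t2):
    # Mirrors tf.map_fn: for each i, gather along axis 1 (the last axis) of t1[i].
    return np.stack([np.take(t1[i], t2[i], axis=1) for i in range(t1.shape[0])])

def fancy_index_version(t1, t2):
    # NumPy advanced indexing: result[i, j, k] = t1[i, j, t2[i, k]].
    b = np.arange(t1.shape[0])[:, None, None]  # shape (B, 1, 1)
    j = np.arange(t1.shape[1])[None, :, None]  # shape (1, R, 1)
    return t1[b, j, t2[:, None, :]]            # t2 broadcast to (B, 1, K)

tensor3 = loop_version(tensor1, tensor2)
print(tensor3.shape)  # (2, 3, 3)
print(np.array_equal(tensor3, fancy_index_version(tensor1, tensor2)))  # True
```

The index arrays broadcast to shape (B, R, K), so the vectorized form produces the same result as the loop without iterating in Python.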