Animesh Karnewar

Reputation: 436

Tensorflow: cross index slicing of a tensor

I have two tensors of the following shape:

tensor1 => shape(10, 99, 106)
tensor2 => shape(10, 99)

tensor2 contains values ranging from 0 to 105, which I wish to use to slice the last dimension of tensor1 and obtain tensor3 of shape

tensor3 => shape(10, 99, 99)

I have tried using:

tensor4 = tf.gather(tensor1, tensor2)
# this causes tensor4 to be of shape (10, 99, 99, 106)

Also, using

tensor4 = tf.gather_nd(tensor1, tensor2)
# gives the error: the last dimension of tensor2 (which is 99) must be
# at most the rank of tensor1 (which is 3).
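tf.gather_nd treats the last dimension of its indices as one full coordinate into params, so here that dimension can have length at most 3. One way to make gather_nd work, sketched below as an alternative the question did not try, is to build a (10, 99, 99, 3) tensor of (i, j, tensor2[i, k]) triples. This assumes the intended result is tensor3[i, j, k] = tensor1[i, j, tensor2[i, k]] and TensorFlow 2.x; the helper name cross_gather_nd is hypothetical and the data is random:

```python
import tensorflow as tf

def cross_gather_nd(tensor1, tensor2):
    """Hypothetical helper: out[i, j, k] = tensor1[i, j, tensor2[i, k]]."""
    b = tf.shape(tensor1)[0]  # batch size (10 in the question)
    n = tf.shape(tensor1)[1]  # second dimension (99)
    m = tf.shape(tensor2)[1]  # indices per batch entry (99)
    target = tf.stack([b, n, m])
    # Broadcast each coordinate component to shape (b, n, m).
    i = tf.broadcast_to(tf.range(b)[:, None, None], target)
    j = tf.broadcast_to(tf.range(n)[None, :, None], target)
    k = tf.broadcast_to(tf.cast(tensor2, tf.int32)[:, None, :], target)
    # Last dimension has length 3, equal to the rank of tensor1.
    coords = tf.stack([i, j, k], axis=-1)  # shape (b, n, m, 3)
    return tf.gather_nd(tensor1, coords)

tensor1 = tf.random.uniform((10, 99, 106))
tensor2 = tf.random.uniform((10, 99), maxval=106, dtype=tf.int32)
tensor3 = cross_gather_nd(tensor1, tensor2)
print(tensor3.shape)  # (10, 99, 99)
```

This materializes an index tensor three times the size of the output, so it trades memory for avoiding a per-batch loop.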

What I am looking for is something that resembles NumPy's cross-indexing for this.
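Assuming the intended result is tensor3[i, j, k] = tensor1[i, j, tensor2[i, k]], the NumPy version is a single advanced-indexing expression with three index arrays that broadcast to (10, 99, 99). A minimal sketch with random placeholder data; the helper name numpy_cross_index is hypothetical:

```python
import numpy as np

def numpy_cross_index(tensor1, tensor2):
    """Hypothetical helper: out[i, j, k] = tensor1[i, j, tensor2[i, k]]."""
    b, n, _ = tensor1.shape
    rows = np.arange(b)[:, None, None]  # shape (b, 1, 1)
    cols = np.arange(n)[None, :, None]  # shape (1, n, 1)
    picks = tensor2[:, None, :]         # shape (b, 1, m)
    # The three index arrays broadcast together to shape (b, n, m).
    return tensor1[rows, cols, picks]

rng = np.random.default_rng(0)
tensor1 = rng.random((10, 99, 106))
tensor2 = rng.integers(0, 106, size=(10, 99))
tensor3 = numpy_cross_index(tensor1, tensor2)
print(tensor3.shape)  # (10, 99, 99)
```

np.ix_ builds similar broadcastable index arrays, but it only accepts 1-D sequences, so it cannot express the per-row indices in tensor2 directly.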

Upvotes: 1

Views: 80

Answers (1)

Lior

Reputation: 2019

You can use tf.map_fn:

 import tensorflow as tf
 tensor3 = tf.map_fn(lambda u: tf.gather(u[0], u[1], axis=1), [tensor1, tensor2], dtype=tensor1.dtype)

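You can think of this line as a loop that runs over the first dimension of tensor1 and tensor2: for each index i in that dimension, it applies tf.gather along axis 1 to tensor1[i,:,:] (shape (99, 106)) with indices tensor2[i,:] (shape (99,)), producing a (99, 99) slice, and the slices are stacked into tensor3 of shape (10, 99, 99).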
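To make the loop analogy concrete, the sketch below computes the result three ways and checks they agree: an explicit Python loop, the tf.map_fn call from this answer, and tf.gather with batch_dims=1, a vectorized alternative not mentioned in the answer that pairs tensor1[i] with tensor2[i] without map_fn. It assumes TensorFlow 2.3 or later (where fn_output_signature is the preferred name for map_fn's deprecated dtype argument) and uses random placeholder data:

```python
import tensorflow as tf

tensor1 = tf.random.uniform((10, 99, 106))
tensor2 = tf.random.uniform((10, 99), maxval=106, dtype=tf.int32)

# 1) Explicit loop: gather along axis 1 of each (99, 106) slice.
looped = tf.stack([tf.gather(tensor1[i], tensor2[i], axis=1) for i in range(10)])

# 2) tf.map_fn as in the answer, with fn_output_signature instead of dtype.
mapped = tf.map_fn(lambda u: tf.gather(u[0], u[1], axis=1),
                   [tensor1, tensor2],
                   fn_output_signature=tensor1.dtype)

# 3) Vectorized: batch_dims=1 shares the first dimension, axis=2 gathers
#    from the last one. Shape = params.shape[:2] + indices.shape[1:].
batched = tf.gather(tensor1, tensor2, axis=2, batch_dims=1)

print(looped.shape, mapped.shape, batched.shape)  # (10, 99, 99) (10, 99, 99) (10, 99, 99)
```

Because all three are pure gathers, their outputs should match exactly, not just approximately.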

Upvotes: 1
