ravi

Reputation: 6338

Incorporating dim parameter of torch.topk in tf.nn.top_k

PyTorch provides torch.topk(input, k, dim=None, largest=True, sorted=True), which returns the k largest elements of the input tensor along a given dimension dim.

I have a tensor of shape (16, 512, 4096) and I am using torch.topk as follows:

# inputs.shape (16L, 512L, 4096L)
dist, idx = torch.topk(inputs, 64, dim=2, largest=False, sorted=False)
# dist.shape (16L, 512L, 64L), idx.shape (16L, 512L, 64L)

I found a similar TensorFlow function: tf.nn.top_k(input, k=1, sorted=True, name=None).

My question is how to incorporate the dim=2 parameter into tf.nn.top_k so that the result has the same shape as the one computed by PyTorch.

Upvotes: 2

Views: 1151

Answers (1)

xdurch0

Reputation: 10474

tf.nn.top_k works on the last dimension of the input. This means that it should work as is for your example:

dist, idx = tf.nn.top_k(inputs, 64, sorted=False)

In general, you can think of the TensorFlow version as the PyTorch version with dim hardcoded to -1, i.e. the last dimension.
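For a dimension other than the last one, one possible approach is to move that axis to the end, call tf.nn.top_k, and move it back. The following is a minimal sketch, assuming TensorFlow 2.x in eager mode; the helper name top_k_along_axis is hypothetical and not part of TensorFlow:

```python
import tensorflow as tf

def top_k_along_axis(inputs, k, axis, sorted=True):
    """Hypothetical helper: top-k along an arbitrary axis via transposition."""
    rank = len(inputs.shape)
    axis = axis % rank  # allow negative axes such as -2
    # Permutation that swaps `axis` with the last dimension.
    perm = list(range(rank))
    perm[axis], perm[-1] = perm[-1], perm[axis]
    values, indices = tf.nn.top_k(tf.transpose(inputs, perm), k=k, sorted=sorted)
    # A swap is its own inverse, so the same perm restores the original layout.
    return tf.transpose(values, perm), tf.transpose(indices, perm)

x = tf.random.uniform((4, 10, 6))
values, indices = top_k_along_axis(x, k=3, axis=1)
print(values.shape, indices.shape)  # both have shape (4, 3, 6)
```

The returned indices refer to positions along the original axis, so they can be used the same way as the idx returned by torch.topk with the corresponding dim.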

However, it looks like you actually want the k smallest elements (largest=False). In that case we could do

dist, idx = tf.nn.top_k(-1*inputs, 64, sorted=False)
dist = -1*dist

So we take the k largest of the negated inputs, which are the k smallest of the original inputs, and then negate the returned values to undo the sign flip. The indices need no adjustment.
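As a quick sanity check of the negation trick, here is a small sketch on a toy tensor (assuming TensorFlow 2.x in eager mode; the variable names are illustrative only):

```python
import tensorflow as tf

x = tf.constant([[5.0, 1.0, 4.0, 2.0, 3.0]])

# The k largest of -x are the k smallest of x; negate the values back afterwards.
neg_values, idx = tf.nn.top_k(-x, k=2, sorted=True)
smallest = -neg_values

print(smallest.numpy())  # [[1. 2.]]
print(idx.numpy())       # [[1 3]]

# The indices still point into the original tensor.
gathered = tf.gather(x, idx, batch_dims=1)
print(gathered.numpy())  # [[1. 2.]]
```

Note that with sorted=True the negated call returns the smallest values in ascending order, since the largest negated value corresponds to the smallest original one.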

Upvotes: 1
