Snehal

Reputation: 757

How can I select the top-n elements from a tensor without repeating elements?

I want to select the top-n elements of a 3-dimensional tensor such that the picked elements are all unique. Elements are ranked by the 2nd column. In the example below I'm selecting the top 2, but I don't want duplicates in the result.

input_tensor = tf.constant([
  [[2.0, 1.0],
   [2.0, 1.0],
   [3.0, 0.4],
   [1.0, 0.1]],
  [[44.0, 0.8],
   [22.0, 0.7],
   [11.0, 0.5],
   [11.0, 0.5]],
  [[5555.0, 0.8],
   [3333.0, 0.7],
   [4444.0, 0.4],
   [1111.0, 0.1]],
  [[444.0, 0.8],
   [333.0, 1.1],
   [333.0, 1.1],
   [111.0, 0.1]]
])
>> TOPK = 2
>> topk_results = tf.gather(
    input_tensor,
    tf.math.top_k(input_tensor[:, :, 1], k=TOPK, sorted=True).indices,
    batch_dims=1
)
>> topk_results.numpy().tolist()
[[[2.0, 1.0], [2.0, 1.0]],
 [[44.0, 0.8], [22.0, 0.7]],
 [[5555.0, 0.8], [3333.0, 0.7]],
 [[333.0, 1.1], [333.0, 1.1]]]

The first and last groups contain duplicates. The output I want is:

[[[2.0, 1.0], [3.0, 0.4]],       # [3.0, 0.4] is the 2nd highest element based on 2nd column
 [[44.0, 0.8], [22.0, 0.7]],
 [[5555.0, 0.8], [3333.0, 0.7]],
 [[333.0, 1.1], [444.0, 0.8]]]   # [444.0, 0.8] is the 2nd highest element based on 2nd column

Upvotes: 2

Views: 1154

Answers (1)

javidcf

Reputation: 59701

Here is one possible way to do it, although it requires more work, since it first sorts each group by the first column.

import tensorflow as tf
import numpy as np

# Input data
k = 2
input_tensor = tf.constant([
  [[2.0, 1.0],
   [2.0, 1.0],
   [3.0, 0.4],
   [1.0, 0.1]],
  [[44.0, 0.8],
   [22.0, 0.7],
   [11.0, 0.5],
   [11.0, 0.5]],
  [[5555.0, 0.8],
   [3333.0, 0.7],
   [4444.0, 0.4],
   [1111.0, 0.1]],
  [[444.0, 0.8],
   [333.0, 1.1],
   [333.0, 1.1],
   [111.0, 0.1]]
])
# Sort by first column
idx = tf.argsort(input_tensor[..., 0], axis=-1)
s = tf.gather_nd(input_tensor, tf.expand_dims(idx, axis=-1), batch_dims=1)
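# Find repeated elements: after sorting, a repeat equals its left neighbour,
# so mask is True only for the first occurrence of each first-column value
col1 = s[..., 0]
col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])
# Replace the score of repeated elements with the lowest finite value of the
# dtype (acting as "minus infinity") so top_k picks them last
col2 = s[..., 1]
col2_masked = tf.where(mask, col2, col2.dtype.min)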
# Get top-k results
topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices
topk_results = tf.gather(s, topk_idx, batch_dims=1)
# Print
with np.printoptions(suppress=True):
    print(topk_results.numpy())
# [[[   2.     1. ]
#   [   3.     0.4]]
# 
#  [[  44.     0.8]
#   [  22.     0.7]]
# 
#  [[5555.     0.8]
#   [3333.     0.7]]
# 
#  [[ 333.     1.1]
#   [ 444.     0.8]]]
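
For a quick sanity check that does not depend on TensorFlow, the intended semantics can be written down in plain Python and compared against the tensor result. This is only a rough sketch: the names `unique_top_k` and `groups` are made up for illustration, and it assumes, like the code above, that two rows count as duplicates when their first column matches.

```python
def unique_top_k(groups, k):
    """For each group of [value, score] rows, keep the k best-scoring rows
    whose first-column value has not been picked yet."""
    results = []
    for rows in groups:
        picked, seen = [], set()
        # sorted() is stable, so rows with equal scores keep their input order.
        for value, score in sorted(rows, key=lambda r: r[1], reverse=True):
            if value in seen:
                continue  # skip repeats of an already picked value
            seen.add(value)
            picked.append([value, score])
            if len(picked) == k:
                break
        results.append(picked)
    return results


groups = [
    [[2.0, 1.0], [2.0, 1.0], [3.0, 0.4], [1.0, 0.1]],
    [[44.0, 0.8], [22.0, 0.7], [11.0, 0.5], [11.0, 0.5]],
    [[5555.0, 0.8], [3333.0, 0.7], [4444.0, 0.4], [1111.0, 0.1]],
    [[444.0, 0.8], [333.0, 1.1], [333.0, 1.1], [111.0, 0.1]],
]
reference = unique_top_k(groups, 2)
print(reference)
```

Unlike the tensor version, this reference returns fewer than k rows for a group that has fewer than k distinct values, so the two only agree when every group has at least k distinct values.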

Note that there is a corner case: a group may have fewer than k distinct elements. In that case, the sort-and-mask approach fills the remaining slots with repeated elements placed at the end. Since those repeats keep their original scores, the result is then no longer ordered by score.
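
One possible way to cope with that corner case is to return a validity flag for each selected row, so the caller can tell genuine picks from filler repeats. The sketch below follows the same sort-and-mask steps. The function name `unique_top_k_with_validity` and the `valid` output are additions for illustration, not part of any TensorFlow API.

```python
import tensorflow as tf


def unique_top_k_with_validity(input_tensor, k):
    # Sort each group by the first column so repeated values become adjacent.
    idx = tf.argsort(input_tensor[..., 0], axis=-1)
    s = tf.gather(input_tensor, idx, batch_dims=1)
    # True for the first occurrence of each first-column value, False for repeats.
    col1 = s[..., 0]
    col1_ext = tf.concat([col1[..., :1] - 1, col1], axis=-1)
    mask = tf.math.not_equal(col1_ext[..., 1:], col1_ext[..., :-1])
    # Push repeats to the bottom of the ranking.
    col2_masked = tf.where(mask, s[..., 1], s.dtype.min)
    topk_idx = tf.math.top_k(col2_masked, k=k, sorted=True).indices
    topk_results = tf.gather(s, topk_idx, batch_dims=1)
    # A selected row is valid only if it was not a masked repeat.
    valid = tf.gather(mask, topk_idx, batch_dims=1)
    return topk_results, valid


# One group with a single distinct value: only the first slot is a real pick.
degenerate = tf.constant([[[1.0, 0.5], [1.0, 0.5], [1.0, 0.5]]])
results, valid = unique_top_k_with_validity(degenerate, k=2)
print(valid.numpy())
```

Rows where `valid` is False can then be dropped or overwritten with a placeholder, depending on what the downstream code expects.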

Upvotes: 1
