Alex Trevithick

Reputation: 791

How to sort one tensor based on the order of another in TensorFlow?

I want to sort a tensor based on the sorted order of another tensor. The tensor to sort has the same shape as the key tensor plus one extra trailing axis. For example, if I have some meaningful values associated with feature vectors, I would like to sort the feature vectors by these values. See the example:

import tensorflow as tf

# Meaningful values with shape (2,3)
z_vals = tf.constant([[3,1,2],[-2,-1,0]])

# Feature vectors, shape (2,3,10), to be sorted by the ascending order of z_vals
features = tf.broadcast_to(z_vals[...,None], (2,3,10))

# Desired result

array([[[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]], dtype=int32)

That is, I want the feature vectors in features reordered along axis 1 so that they follow the ascending order of z_vals.
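The "sorted order" of z_vals is a per-row permutation of indices, which is exactly what `tf.argsort` returns. A minimal sketch, assuming TensorFlow 2.x in eager mode, that prints that permutation for the example above and checks it against `tf.sort`:

```python
import tensorflow as tf

# Keys from the question, shape (2, 3)
z_vals = tf.constant([[3, 1, 2], [-2, -1, 0]])

# For each row, argsort gives the positions of the elements in ascending order
inds = tf.argsort(z_vals, axis=-1)
print(inds.numpy())
# [[1 2 0]
#  [0 1 2]]

# Applying the permutation to the keys themselves reproduces tf.sort
print(tf.gather(z_vals, inds, batch_dims=1).numpy())
# [[ 1  2  3]
#  [-2 -1  0]]
```

Any tensor that shares the leading (2, 3) shape can be reordered with the same inds.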

Upvotes: 2

Views: 1480

Answers (1)

Alex Trevithick

Reputation: 791

I think I have a solution that works, but I'm not sure if there is a much simpler way to do this.

What I do is:

import tensorflow as tf

# Meaningful values with shape (2,3)
z_vals = tf.constant([[3,1,2],[-2,-1,0]])
features = tf.broadcast_to(z_vals[...,None], (2,3,10))

# Get the sorting permutation along the last axis
inds = tf.argsort(z_vals, axis=-1)

# With batch_dims=1, row b of inds indexes into row b of features along
# axis 1, so each length-10 feature vector moves together with its z value.
sorted_features = tf.gather(features, inds, batch_dims=1)

<tf.Tensor: id=1173, shape=(2, 3, 10), dtype=int32, numpy=
array([[[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]], dtype=int32)>
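For comparison, here is a sketch (assuming TensorFlow 2.x eager mode) of the same reordering written with `tf.gather_nd` and explicit (batch, position) index pairs. It makes visible what `batch_dims=1` does for you; the names batch_idx, nd_inds and sorted_features_nd are illustrative only.

```python
import tensorflow as tf

z_vals = tf.constant([[3, 1, 2], [-2, -1, 0]])
features = tf.broadcast_to(z_vals[..., None], (2, 3, 10))
inds = tf.argsort(z_vals, axis=-1)  # shape (2, 3)

# batch_idx[b, j] == b, so it pairs each sort index with its own batch row
batch_size = tf.shape(inds)[0]
batch_idx = tf.broadcast_to(tf.range(batch_size)[:, None], tf.shape(inds))

# Stack into (batch, position) pairs, shape (2, 3, 2)
nd_inds = tf.stack([batch_idx, inds], axis=-1)

# Each pair (b, inds[b, j]) selects one length-10 feature vector -> shape (2, 3, 10)
sorted_features_nd = tf.gather_nd(features, nd_inds)
```

The result should match `tf.gather(features, inds, batch_dims=1)`; the `batch_dims` form is shorter and avoids building batch_idx by hand.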

If there is a simpler way, please let me know.
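The same `tf.argsort` + `tf.gather` approach can be wrapped so it works for any number of leading batch axes by setting `batch_dims` to the key tensor's rank minus one. A minimal sketch, assuming TensorFlow 2.x and keys with a statically known rank; `sort_by_keys` is a hypothetical helper name, not a TensorFlow API:

```python
import tensorflow as tf

def sort_by_keys(features, keys, descending=False):
    """Hypothetical helper: reorder features (shape [..., N, D]) by sorting
    keys (shape [..., N]) along their last axis."""
    inds = tf.argsort(
        keys,
        axis=-1,
        direction="DESCENDING" if descending else "ASCENDING",
        stable=True,  # ties keep their original relative order
    )
    # Every axis of keys except the last is treated as a batch axis.
    # tf.gather's axis defaults to batch_dims, i.e. the N axis of features.
    batch_dims = keys.shape.rank - 1
    return tf.gather(features, inds, batch_dims=batch_dims)

z_vals = tf.constant([[3, 1, 2], [-2, -1, 0]])
features = tf.broadcast_to(z_vals[..., None], (2, 3, 10))
print(sort_by_keys(features, z_vals)[..., 0].numpy())
# [[ 1  2  3]
#  [-2 -1  0]]
```

For the (2, 3) keys in the question this reduces to exactly `batch_dims=1`; the helper only matters when there are more leading axes.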

Upvotes: 1
