I want to sort one tensor by the sorted order of another tensor whose shape is the same except that it lacks the last axis. For example, if each feature vector has a meaningful value associated with it, I would like to sort the feature vectors by those values. See the example:
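import tensorflow as tf

# Meaningful values with shape (2,3)
z_vals = tf.constant([[3,1,2],[-2,-1,0]])
# Features with shape (2,3,10): feature vector j of row i is z_vals[i,j] repeated 10 times.
# I want these feature vectors sorted by the ascending order of z_vals in each row.
features = tf.broadcast_to(z_vals[...,None], (2,3,10))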
# Desired result
array([[[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]], dtype=int32)
That is, I want the feature vectors in features reordered along the second axis so that, in each row, they follow the ascending order of z_vals.
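To pin down the intended semantics, here is a minimal pure-Python sketch (no TensorFlow; the helper name sort_by_keys is hypothetical) that performs the same reordering on nested lists: argsort each row of keys, then pick the feature vectors of that row in that order.

```python
def sort_by_keys(keys, feats):
    """Hypothetical helper: reorder each batch row of feats by ascending keys.

    keys:  B lists of N sortable values (like z_vals, shape (B, N))
    feats: B lists of N feature vectors (like features, shape (B, N, D))
    """
    out = []
    for row_keys, row_feats in zip(keys, feats):
        # "argsort": positions of row_keys listed in ascending order of value
        order = sorted(range(len(row_keys)), key=lambda i: row_keys[i])
        # "gather": take this row's feature vectors in that order
        out.append([row_feats[i] for i in order])
    return out


z_vals = [[3, 1, 2], [-2, -1, 0]]
# Same construction as the broadcast in the question: each value repeated 10 times
features = [[[v] * 10 for v in row] for row in z_vals]

result = sort_by_keys(z_vals, features)
for row in result:
    print([vec[0] for vec in row])
# prints:
# [1, 2, 3]
# [-2, -1, 0]
```

Python's sorted is stable, so tied keys keep their original order; tf.argsort only guarantees that when called with stable=True.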
I think I have a solution that works, but I'm not sure if there is a much simpler way to do this.
What I do is:
import tensorflow as tf

# Meaningful values with shape (2,3)
z_vals = tf.constant([[3,1,2],[-2,-1,0]])
features = tf.broadcast_to(z_vals[...,None], (2,3,10))
# Get the sorted indices along the last axis of z_vals; shape (2,3)
inds = tf.argsort(z_vals, -1)
# With batch_dims=1 (axis defaults to batch_dims, i.e. the second axis),
# gather works row by row: it picks features[b, inds[b, i], :] for each b.
sorted_features = tf.gather(features, inds, batch_dims=1)
<tf.Tensor: id=1173, shape=(2, 3, 10), dtype=int32, numpy=
array([[[ 1,  1,  1,  1,  1,  1,  1,  1,  1,  1],
        [ 2,  2,  2,  2,  2,  2,  2,  2,  2,  2],
        [ 3,  3,  3,  3,  3,  3,  3,  3,  3,  3]],

       [[-2, -2, -2, -2, -2, -2, -2, -2, -2, -2],
        [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0]]], dtype=int32)>
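For reference, the indexing that tf.gather(features, inds, batch_dims=1) performs can be written out explicitly. Writing B, N, D for the sizes of the three axes of features (here 2, 3, 10), and with inds[b] being the ascending argsort of z_vals[b]:

```latex
\text{sorted\_features}[b, i, k] = \text{features}\big[b,\ \text{inds}[b, i],\ k\big],
\qquad 0 \le b < B,\quad 0 \le i < N,\quad 0 \le k < D
```

For example, inds[0] = [1, 2, 0], so sorted_features[0, 0, k] = features[0, 1, k] = 1, which matches the first row of the output above.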
If there is a simpler way, please let me know.