Soli Technology LLC

Reputation: 353

How to sort a 4D tensor based on the first value in the 4th dimension?

I have a 4D tensor that I want to sort. The order of values within the 4th dimension must stay the same, but I want to sort the arrays in the 3rd dimension by the first value in the 4th dimension. I am using TensorFlow 2.11. I have tried tf.argsort() and tf.gather_nd(), but I can't make it work.

For example, I have the following tensor:

<tf.Tensor: shape=(1, 6, 4, 2), dtype=int64, numpy=
array([[[[51, 92],
         [14, 71],
         [60, 20],
         [82, 86]],
        [[74, 74],
         [87, 99],
         [23,  2],
         [21, 52]],
        [[ 1, 87],
         [29, 37],
         [ 1, 63],
         [59, 20]],
        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],
        [[41, 91],
         [59, 79],
         [14, 61],
         [61, 46]],
        [[61, 50],
         [54, 63],
         [ 2, 50],
         [ 6, 20]]]])>

I want it to be sorted like this:

<tf.Tensor: shape=(1, 6, 4, 2), dtype=int64, numpy=
array([[[[14, 71],
         [51, 92],
         [60, 20],
         [82, 86]],
        [[21, 52],
         [23,  2],
         [74, 74],
         [87, 99]],
        [[ 1, 87],
         [ 1, 63],
         [29, 37],
         [59, 20]],
        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],
        [[14, 61],
         [41, 91],
         [59, 79],
         [61, 46]],
        [[ 2, 50],
         [ 6, 20],
         [54, 63],
         [61, 50]]]])>

Upvotes: 0

Views: 77

Answers (1)

adrianop01

Reputation: 600

Try tf.argsort() on the first value of each pair, followed by tf.experimental.numpy.take_along_axis():

import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

a = tf.constant(np.array([[[[51, 92],
         [14, 71],
         [60, 20],
         [82, 86]],
        [[74, 74],
         [87, 99],
         [23,  2],
         [21, 52]],
        [[ 1, 87],
         [29, 37],
         [ 1, 63],
         [59, 20]],
        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],
        [[41, 91],
         [59, 79],
         [14, 61],
         [61, 46]],
        [[61, 50],
         [54, 63],
         [ 2, 50],
         [ 6, 20]]]]))

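# Indices that sort each group of 4 pairs by the pair's first value; shape (1, 6, 4).
# stable=True keeps pairs with equal first values (the two 1s) in their original order.
args = tf.argsort(a[..., 0], axis=2, stable=True)
# take_along_axis() needs indices with the same rank as `a`, so add a trailing
# axis of size 1; it broadcasts across both values of each pair.
args = tf.reshape(args, (a.shape[0], a.shape[1], a.shape[2], 1))
result = tnp.take_along_axis(a, args, axis=2)
print(repr(result))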

<tf.Tensor: shape=(1, 6, 4, 2), dtype=int64, numpy=
array([[[[14, 71],
         [51, 92],
         [60, 20],
         [82, 86]],

        [[21, 52],
         [23,  2],
         [74, 74],
         [87, 99]],

        [[ 1, 87],
         [ 1, 63],
         [29, 37],
         [59, 20]],

        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],

        [[14, 61],
         [41, 91],
         [59, 79],
         [61, 46]],

        [[ 2, 50],
         [ 6, 20],
         [54, 63],
         [61, 50]]]])>

Upvotes: 2
