I have a 4D tensor that I want to sort. The values inside the 4th dimension must keep their order, but I want to reorder the rows along the 3rd dimension (each row being a length-2 array in the 4th dimension) by each row's first value. I am using TensorFlow 2.11. I have tried tf.argsort() and tf.gather_nd(), but I can't make it work.
For example, I have the following tensor:
<tf.Tensor: shape=(1, 6, 4, 2), dtype=int64, numpy=
array([[[[51, 92],
         [14, 71],
         [60, 20],
         [82, 86]],

        [[74, 74],
         [87, 99],
         [23,  2],
         [21, 52]],

        [[ 1, 87],
         [29, 37],
         [ 1, 63],
         [59, 20]],

        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],

        [[41, 91],
         [59, 79],
         [14, 61],
         [61, 46]],

        [[61, 50],
         [54, 63],
         [ 2, 50],
         [ 6, 20]]]])>
I want it to be sorted like this:
<tf.Tensor: shape=(1, 6, 4, 2), dtype=int64, numpy=
array([[[[14, 71],
         [51, 92],
         [60, 20],
         [82, 86]],

        [[21, 52],
         [23,  2],
         [74, 74],
         [87, 99]],

        [[ 1, 87],
         [ 1, 63],
         [29, 37],
         [59, 20]],

        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],

        [[14, 61],
         [41, 91],
         [59, 79],
         [61, 46]],

        [[ 2, 50],
         [ 6, 20],
         [54, 63],
         [61, 50]]]])>
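Since the question mentions tf.gather_nd(), here is a minimal sketch of how it can do this reordering, assuming a rank-4 tensor whose shape is known (eager mode). The usual stumbling block is that gather_nd needs a full (i, j, k) index for every output row, not just the sorted k that tf.argsort() returns, so the i and j coordinates are built with tf.meshgrid. The names order, idx and sorted_a are illustrative only.

```python
import tensorflow as tf

a = tf.constant([[[[51, 92], [14, 71], [60, 20], [82, 86]],
                  [[74, 74], [87, 99], [23, 2], [21, 52]],
                  [[1, 87], [29, 37], [1, 63], [59, 20]],
                  [[32, 75], [57, 21], [88, 48], [90, 58]],
                  [[41, 91], [59, 79], [14, 61], [61, 46]],
                  [[61, 50], [54, 63], [2, 50], [6, 20]]]], dtype=tf.int64)

# Row permutation for every (i, j) group: argsort the first column along axis 2.
# stable=True keeps rows with equal keys in their original order.
order = tf.argsort(a[..., 0], axis=2, stable=True)  # shape (1, 6, 4), int32

# gather_nd needs a complete (i, j, k) coordinate per output row, so build
# the i and j coordinates with a meshgrid and stack them with the sorted k.
d0, d1, d2 = a.shape[0], a.shape[1], a.shape[2]
i, j, _ = tf.meshgrid(tf.range(d0), tf.range(d1), tf.range(d2), indexing="ij")
idx = tf.stack([i, j, order], axis=-1)  # shape (1, 6, 4, 3)

# idx has 3 coordinates for a rank-4 tensor, so each entry selects a whole
# length-2 row a[i, j, order[i, j, k], :], leaving the row contents untouched.
sorted_a = tf.gather_nd(a, idx)
print(sorted_a.numpy())
```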
Try tf.argsort() on the first value of each row, followed by tf.experimental.numpy.take_along_axis():
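import tensorflow as tf
import tensorflow.experimental.numpy as tnp

# dtype=tf.int64 matches the question (a plain Python int list would give int32).
a = tf.constant([[[[51, 92],
                   [14, 71],
                   [60, 20],
                   [82, 86]],
                  [[74, 74],
                   [87, 99],
                   [23,  2],
                   [21, 52]],
                  [[ 1, 87],
                   [29, 37],
                   [ 1, 63],
                   [59, 20]],
                  [[32, 75],
                   [57, 21],
                   [88, 48],
                   [90, 58]],
                  [[41, 91],
                   [59, 79],
                   [14, 61],
                   [61, 46]],
                  [[61, 50],
                   [54, 63],
                   [ 2, 50],
                   [ 6, 20]]]], dtype=tf.int64)

# Sort order of the rows along axis 2, keyed on the first value of each row.
# stable=True keeps rows with equal keys (the two rows starting with 1) in their original order.
args = tf.argsort(a[..., 0], axis=2, stable=True)
# Add a trailing length-1 axis so the indices broadcast over the last dimension,
# which moves each row as a whole.
args = tf.expand_dims(args, axis=-1)

sorted_a = tnp.take_along_axis(a, args, axis=2)
print(repr(sorted_a))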
<tf.Tensor: shape=(1, 6, 4, 2), dtype=int64, numpy=
array([[[[14, 71],
         [51, 92],
         [60, 20],
         [82, 86]],

        [[21, 52],
         [23,  2],
         [74, 74],
         [87, 99]],

        [[ 1, 87],
         [ 1, 63],
         [29, 37],
         [59, 20]],

        [[32, 75],
         [57, 21],
         [88, 48],
         [90, 58]],

        [[14, 61],
         [41, 91],
         [59, 79],
         [61, 46]],

        [[ 2, 50],
         [ 6, 20],
         [54, 63],
         [61, 50]]]])>
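tf.gather() with batch_dims should give the same result without going through tf.experimental.numpy. A sketch, assuming a rank-4 tensor sorted along axis 2 by its first column; sort_rows_by_first_value is a hypothetical helper name, not a TensorFlow API:

```python
import tensorflow as tf


def sort_rows_by_first_value(x):
    """Hypothetical helper: reorder the rows along axis 2 of a rank-4 tensor
    by each row's first element.

    Each row x[i, j, k, :] moves as a unit, so values inside a row keep their
    order; only the rows within each x[i, j] are permuted.
    """
    # Sorting permutation per (i, j) group, shape (d0, d1, d2).
    order = tf.argsort(x[..., 0], axis=2, stable=True)
    # batch_dims=2 pairs order[i, j] with x[i, j]; axis=2 is the row axis.
    return tf.gather(x, order, axis=2, batch_dims=2)


a = tf.constant([[[[51, 92], [14, 71], [60, 20], [82, 86]],
                  [[74, 74], [87, 99], [23, 2], [21, 52]],
                  [[1, 87], [29, 37], [1, 63], [59, 20]],
                  [[32, 75], [57, 21], [88, 48], [90, 58]],
                  [[41, 91], [59, 79], [14, 61], [61, 46]],
                  [[61, 50], [54, 63], [2, 50], [6, 20]]]], dtype=tf.int64)

result = sort_rows_by_first_value(a)
print(result.numpy())
```

batch_dims=2 tells tf.gather that the first two axes of order and x correspond one-to-one, so each x[i, j] is indexed only by its own permutation order[i, j]. The output shape is x.shape[:2] + order.shape[2:] + x.shape[3:], i.e. the same (1, 6, 4, 2) as the input.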