KalvinB


TensorFlow 2.2.0: gather maximum elements over one dimension

I have the following problem:

I have a tensor of shape (1600, 29) and I want to obtain the maximum along axis 1. The result should be a (1600, 1) tensor. For simplicity, I will use a (5, 3) tensor to demonstrate my problem:

import tensorflow as tf

B = tf.constant([[2, 20, 30],
                 [2, 7, 6],
                 [3, 11, 16],
                 [19, 1, 8],
                 [14, 45, 23]])

x = tf.math.argmax(B, axis=1)  # 5 values: [2 1 2 0 1]
z = tf.gather(B, x, axis=1)    # shape (5, 5):
# [[30 20 30  2 20]
#  [ 6  7  6  2  7]
#  [16 11 16  3 11]
#  [ 8  1  8 19  1]
#  [23 45 23 14 45]]

Now x gives me the index of the maximum element along axis 1. However, tf.gather does not return [30, 7, 16, 19, 45] but the (5, 5) tensor above. How do I "reduce" the dimension properly?
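A minimal sketch of why the (5, 5) shape appears and one way around it, assuming TF 2.x eager execution: with axis=1 and the default batch_dims=0, tf.gather looks up all five indices in x for every row of B, so the wanted values end up on the diagonal of z. Passing batch_dims=1 pairs row i of B with x[i] only.

```python
import tensorflow as tf

B = tf.constant([[2, 20, 30],
                 [2, 7, 6],
                 [3, 11, 16],
                 [19, 1, 8],
                 [14, 45, 23]])

# Index of the maximum in each row (int64 indices)
x = tf.math.argmax(B, axis=1)

# Default batch_dims=0: every row is indexed with all 5 entries of x -> shape (5, 5)
z = tf.gather(B, x, axis=1)

# batch_dims=1: row i of B is indexed only with x[i] -> shape (5,)
row_max = tf.gather(B, x, axis=1, batch_dims=1)
print(row_max.numpy())
```

The same pattern applies unchanged to the (1600, 29) tensor, giving a (1600,) result.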

My "rather" dirty way is this:

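eye_z = tf.eye(5, 5)  # float32 identity mask

# B holds int32 values, so cast z to float32 before masking;
# multiplying an int32 tensor by a float32 tensor raises an error in TensorFlow
intermed_result = tf.cast(z, tf.float32) * eye_z  # zero out everything off the diagonal

# Summing each row (matvec with a vector of ones) leaves only the diagonal entries
result = tf.linalg.matvec(intermed_result, tf.ones(5, dtype=tf.float32))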

This gives the correct tensor: [30. 7. 16. 19. 45.]
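The eye/matvec workaround amounts to reading the diagonal of z. A sketch of the same step with tf.linalg.diag_part, assuming B, x and z are rebuilt as in the question so the block runs on its own:

```python
import tensorflow as tf

B = tf.constant([[2, 20, 30],
                 [2, 7, 6],
                 [3, 11, 16],
                 [19, 1, 8],
                 [14, 45, 23]])
x = tf.math.argmax(B, axis=1)
z = tf.gather(B, x, axis=1)  # shape (5, 5); the wanted values sit on the diagonal

# Take the diagonal directly instead of masking with tf.eye and summing rows.
# This also keeps the original int32 dtype, so no cast is needed.
diag = tf.linalg.diag_part(z)
print(diag.numpy())
```

This still builds the square intermediate z, which for the real input would be (1600, 1600); reducing along axis 1 directly (for example with tf.math.reduce_max) avoids that.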


Answers (1)

yatu


tf.math.reduce_max does exactly this:

m = tf.math.reduce_max(B, axis=1)

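# TensorFlow 2.x executes eagerly, so no session is needed
# (tf.InteractiveSession only exists as tf.compat.v1.InteractiveSession in TF 2)
print(m.numpy())
# [30  7 16 19 45]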
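The question asks for a (1600, 1) result, while reduce_max with only axis=1 returns shape (1600,). A minimal sketch, assuming random stand-in data of the question's shape: keepdims=True keeps the reduced axis with length 1.

```python
import tensorflow as tf

# Hypothetical stand-in data with the shape from the question; values are random
data = tf.random.uniform((1600, 29))

row_max = tf.math.reduce_max(data, axis=1)                    # shape (1600,)
row_max_2d = tf.math.reduce_max(data, axis=1, keepdims=True)  # shape (1600, 1)
print(row_max.shape, row_max_2d.shape)
```

Alternatively, tf.expand_dims(row_max, axis=1) or tf.reshape(row_max, (-1, 1)) adds the trailing axis after the fact.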
