Perl Del Rey

How does argmax work when given a 3D tensor? (TensorFlow)

I was wondering how argmax works when given a 3D tensor. I know what happens with a 2D tensor, but the 3D case confuses me a lot.

Example:

import tensorflow as tf
import numpy as np

sess = tf.Session()

coordinates = np.random.randint(0, 100, size=(3, 3, 2))
coordinates
Out[20]: 
array([[[15, 23],
        [ 3,  1],
        [80, 56]],
       [[98, 95],
        [97, 82],
        [10, 37]],
       [[65, 32],
        [25, 39],
        [54, 68]]])
sess.run([tf.argmax(coordinates, axis=1)])
Out[21]: 
[array([[2, 2],
        [0, 0],
        [0, 2]], dtype=int64)]



Answers (1)

Nicolas Gervais

tf.argmax returns the index of the maximum value along the specified axis. That axis is reduced ("crushed"): for every combination of the other indices, the values along the specified axis are compared and the position of the largest one is returned. The result therefore has the same shape as the input, except that the specified axis disappears. I'll use tf.reduce_max in the examples, which performs the same reduction but returns the maximum values instead of their indices, so we can follow the values.
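To make the rule concrete without TensorFlow, here is a minimal pure-Python sketch using the array from the question as plain nested lists. `argmax_axis` and `shape` are hypothetical helpers written for illustration, not TensorFlow or NumPy APIs. The sketch assumes a rectangular, non-empty 3D list, and ties resolve to the first index, which is what `np.argmax` documents.

```python
def argmax_axis(x, axis):
    """Hypothetical helper: argmax of a 3-D nested list along `axis` (0, 1 or 2).

    The chosen axis is reduced away, so the result is a 2-D nested list.
    """
    d0, d1, d2 = len(x), len(x[0]), len(x[0][0])

    def first_max_index(values):
        # Position of the largest value; max() keeps the first one on ties.
        return max(range(len(values)), key=lambda i: values[i])

    if axis == 0:
        # result[j][k] compares x[0][j][k], x[1][j][k], x[2][j][k], ...
        return [[first_max_index([x[i][j][k] for i in range(d0)])
                 for k in range(d2)] for j in range(d1)]
    if axis == 1:
        # result[i][k] compares x[i][0][k], x[i][1][k], x[i][2][k], ...
        return [[first_max_index([x[i][j][k] for j in range(d1)])
                 for k in range(d2)] for i in range(d0)]
    if axis == 2:
        # result[i][j] compares x[i][j][0], x[i][j][1], ...
        return [[first_max_index([x[i][j][k] for k in range(d2)])
                 for j in range(d1)] for i in range(d0)]
    raise ValueError("axis must be 0, 1 or 2")


def shape(t):
    """Hypothetical helper: shape of a rectangular, non-empty nested list."""
    dims = []
    while isinstance(t, list):
        dims.append(len(t))
        t = t[0]
    return tuple(dims)


x = [[[15, 23], [3, 1], [80, 56]],
     [[98, 95], [97, 82], [10, 37]],
     [[65, 32], [25, 39], [54, 68]]]

for axis in range(3):
    result = argmax_axis(x, axis)
    print(axis, shape(result), result)
# 0 (3, 2) [[1, 1], [1, 1], [0, 2]]
# 1 (3, 2) [[2, 2], [0, 0], [0, 2]]
# 2 (3, 3) [[1, 0, 0], [0, 0, 1], [0, 1, 1]]
```

The input has shape (3, 3, 2). Whichever axis you pick is the one missing from the result's shape, and the axis=1 row matches the question's output.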

Let's start with your array:

x = np.array([[[15, 23],
               [3, 1],
               [80, 56]],
              [[98, 95],
               [97, 82],
               [10, 37]],
              [[65, 32],
               [25, 39],
               [54, 68]]])

See tf.reduce_max(x, axis=0):

            ([[[15, 23],
                          [3, 1],
                                     [80, 56]],
              [[98, 95],               ^
                 ^   ^    [97, 82],
                            ^  ^     [10, 37]],
              [[65, 32],
                          [25, 39],
                                     [54, 68]]]) 
                                           ^    
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[98, 95],
       [97, 82],
       [80, 68]])>
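tf.argmax(x, axis=0) picks the same winners but reports where they came from: for each (row, column) position, it gives the index of the 3x2 block that held the maximum. The pure-Python check below is an illustration only. It uses plain nested lists, and ties go to the first index.

```python
x = [[[15, 23], [3, 1], [80, 56]],
     [[98, 95], [97, 82], [10, 37]],
     [[65, 32], [25, 39], [54, 68]]]

# For each (row j, column k), pick the block index i with the largest x[i][j][k].
argmax0 = [[max(range(len(x)), key=lambda i: x[i][j][k])
            for k in range(len(x[0][0]))]
           for j in range(len(x[0]))]

# Reading x back at those indices reproduces tf.reduce_max(x, axis=0).
max0 = [[x[i][j][k] for k, i in enumerate(row)]
        for j, row in enumerate(argmax0)]

print(argmax0)  # [[1, 1], [1, 1], [0, 2]]
print(max0)     # [[98, 95], [97, 82], [80, 68]]
```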

Now tf.reduce_max(x, axis=1):

            ([[[15, 23], [[98, 95],  [[65, 32],
                            ^   ^       ^
               [3, 1],    [97, 82],   [25, 39],
           
               [80, 56]], [10, 37]],  [54, 68]]])
                 ^   ^                      ^
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[80, 56],
       [98, 95],
       [65, 68]])>
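This is the call from the question. tf.argmax(x, axis=1) returns, for each block, the row index of the values marked above. The plain-Python sketch below assumes x is the nested list from the question. Here `zip(*block)` transposes a 3x2 block into its two columns, and `tuple.index` returns the first match on ties.

```python
x = [[[15, 23], [3, 1], [80, 56]],
     [[98, 95], [97, 82], [10, 37]],
     [[65, 32], [25, 39], [54, 68]]]

# axis=1 works inside each 3x2 block: zip(*block) yields its columns,
# and each column's argmax is the row index holding its largest value.
argmax1 = [[col.index(max(col)) for col in zip(*block)] for block in x]
max1 = [[max(col) for col in zip(*block)] for block in x]

print(argmax1)  # [[2, 2], [0, 0], [0, 2]]
print(max1)     # [[80, 56], [98, 95], [65, 68]]
```

Reading the result: in the first block, 80 and 56 both sit in row 2, hence [2, 2]. In the last block, 65 is in row 0 and 68 is in row 2, hence [0, 2].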

Now tf.reduce_max(x, axis=2):

            ([[[15, 23],
                     ^
               [3, 1],
                ^
               [80, 56]],
                ^   
              [[98, 95],
                 ^
               [97, 82],
                 ^
               [10, 37]],
                     ^
              [[65, 32],
                 ^
               [25, 39],
                     ^
               [54, 68]]])
                     ^
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[23,  3, 80],
       [98, 97, 37],
       [65, 39, 68]])>
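For axis=2, the comparison happens inside each innermost pair. So argmax can only return 0 or 1, and the result has shape (3, 3). The pure-Python illustration below follows the same approach as the other axes, and ties again go to the first index.

```python
x = [[[15, 23], [3, 1], [80, 56]],
     [[98, 95], [97, 82], [10, 37]],
     [[65, 32], [25, 39], [54, 68]]]

# axis=2 compares the two numbers inside each innermost pair.
argmax2 = [[pair.index(max(pair)) for pair in block] for block in x]
max2 = [[max(pair) for pair in block] for block in x]

print(argmax2)  # [[1, 0, 0], [0, 0, 1], [0, 1, 1]]
print(max2)     # [[23, 3, 80], [98, 97, 37], [65, 39, 68]]
```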
