Reputation: 1059
I was wondering how argmax works when given a 3D tensor. I understand what happens with a 2D tensor, but the 3D case is confusing me a lot.
Example:
import tensorflow as tf
import numpy as np
sess = tf.Session()
coordinates = np.random.randint(0, 100, size=(3, 3, 2))
coordinates
Out[20]:
array([[[15, 23],
        [ 3,  1],
        [80, 56]],
       [[98, 95],
        [97, 82],
        [10, 37]],
       [[65, 32],
        [25, 39],
        [54, 68]]])
sess.run([tf.argmax(coordinates, axis=1)])
Out[21]:
[array([[2, 2],
        [0, 0],
        [0, 2]], dtype=int64)]
Upvotes: 1
Views: 496
Reputation: 36624
tf.argmax returns the index of the maximum value along the specified axis. That axis is collapsed: for every position along the remaining axes, the index of the largest value is returned. The result has the same shape as the input, except that the specified axis disappears. I'll build the examples with tf.reduce_max, which reduces in the same way but returns the values themselves, so we can follow them.
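To make the shape rule concrete, here is a minimal sketch. It assumes NumPy's np.argmax as a stand-in for tf.argmax (the axis semantics are the same), and it uses a made-up tensor of shape (4, 3, 2), not the one from the question, so that each axis has a distinct length and it is obvious which one disappears.

```python
import numpy as np

# Hypothetical tensor, chosen only so that every axis has a different length.
rng = np.random.default_rng(0)
t = rng.integers(0, 100, size=(4, 3, 2))

# Reduce along each axis in turn and record the shape of the result.
shapes = {axis: np.argmax(t, axis=axis).shape for axis in range(t.ndim)}

for axis, shape in shapes.items():
    print(f"axis={axis}: {t.shape} -> {shape}")
```

The reduced axis is dropped from the shape each time: (3, 2) for axis 0, (4, 2) for axis 1, and (4, 3) for axis 2.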
Let's start with your array:
x = np.array([[[15, 23],
               [3, 1],
               [80, 56]],
              [[98, 95],
               [97, 82],
               [10, 37]],
              [[65, 32],
               [25, 39],
               [54, 68]]])
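First, tf.reduce_max(x, axis=0). Each output entry is the maximum over the three blocks at the same (row, column) position. The carets mark the winners:
([[[15, 23],
   [ 3,  1],
   [80, 56]],
    ^
  [[98, 95],
    ^   ^
   [97, 82],
    ^   ^
   [10, 37]],
  [[65, 32],
   [25, 39],
   [54, 68]]])
        ^
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[98, 95],
       [97, 82],
       [80, 68]])>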
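Now tf.reduce_max(x, axis=1). Within each block, the maximum is taken down each column, over the three rows. The blocks are drawn side by side here:
([[[15, 23],    [[98, 95],    [[65, 32],
                  ^   ^         ^
   [ 3,  1],     [97, 82],     [25, 39],
   [80, 56]],    [10, 37]],    [54, 68]]])
    ^   ^                           ^
<tf.Tensor: shape=(3, 2), dtype=int32, numpy=
array([[80, 56],
       [98, 95],
       [65, 68]])>
tf.argmax(x, axis=1) returns the row positions of these values, which is the [[2, 2], [0, 0], [0, 2]] from your question.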
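The axis=1 case can also be spelled out as explicit loops. The sketch below again assumes NumPy in place of TensorFlow. The names n_blocks, n_rows, n_cols and manual are only illustrative.

```python
import numpy as np

x = np.array([[[15, 23], [3, 1], [80, 56]],
              [[98, 95], [97, 82], [10, 37]],
              [[65, 32], [25, 39], [54, 68]]])

n_blocks, n_rows, n_cols = x.shape

# For output cell (i, k), compare x[i, 0, k], x[i, 1, k], x[i, 2, k]
# and record the row index j of the largest one.
manual = [
    [max(range(n_rows), key=lambda j: x[i, j, k]) for k in range(n_cols)]
    for i in range(n_blocks)
]

print(manual)                         # [[2, 2], [0, 0], [0, 2]]
print(np.argmax(x, axis=1).tolist())  # [[2, 2], [0, 0], [0, 2]]
```

Both lines print the same indices as the sess.run output in the question: the loop runs over the two axes that are kept, and the search runs over the axis that is collapsed.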
Finally, tf.reduce_max(x, axis=2). The maximum is taken within each innermost pair:
([[[15, 23],
        ^
   [ 3,  1],
     ^
   [80, 56]],
    ^
  [[98, 95],
    ^
   [97, 82],
    ^
   [10, 37]],
        ^
  [[65, 32],
    ^
   [25, 39],
        ^
   [54, 68]]])
        ^
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
array([[23,  3, 80],
       [98, 97, 37],
       [65, 39, 68]])>
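To tie this back to argmax, the sketch below computes the indices along every axis and checks that looking them up in x gives back the maxima. It assumes NumPy (np.argmax, np.take_along_axis) as a stand-in for the TensorFlow calls. The indices dictionary is just a helper for this example.

```python
import numpy as np

x = np.array([[[15, 23], [3, 1], [80, 56]],
              [[98, 95], [97, 82], [10, 37]],
              [[65, 32], [25, 39], [54, 68]]])

indices = {}
for axis in range(x.ndim):
    idx = np.argmax(x, axis=axis)
    indices[axis] = idx
    # Re-insert the collapsed axis so idx can be used to look values up in x.
    picked = np.take_along_axis(x, np.expand_dims(idx, axis=axis), axis=axis)
    picked = np.squeeze(picked, axis=axis)
    # The looked-up values are exactly the maxima along that axis.
    assert np.array_equal(picked, np.max(x, axis=axis))
    print(f"axis={axis}\n{idx}")
```

For this array the indices are [[1, 1], [1, 1], [0, 2]] for axis 0, [[2, 2], [0, 0], [0, 2]] for axis 1, and [[1, 0, 0], [0, 0, 1], [0, 1, 1]] for axis 2. Each entry is the position, along the collapsed axis, of the value shown in the corresponding reduce_max output.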
Upvotes: 1