I'm new to TensorFlow and I'm trying to get the index of the maximum value in a tensor. Here is the code:
import tensorflow as tf

def select(input_layer):
    shape = input_layer.get_shape().as_list()
    rel = tf.nn.relu(input_layer)        # (32, 3, 3, 5)
    print(rel)
    redu = tf.reduce_sum(rel, 3)         # sum over the last axis -> (32, 3, 3)
    print(redu)
    location2 = tf.argmax(redu, 1)       # argmax along axis 1 only -> (32, 3)
    print(location2)
    return location2

sess = tf.InteractiveSession()
I = tf.random_uniform([32, 3, 3, 5], minval=-541, maxval=23, dtype=tf.float32)
matI, matO = sess.run([I, select(I)])
print(matI, matO)
Here is the output:
Tensor("Relu:0", shape=(32, 3, 3, 5), dtype=float32)
Tensor("Sum:0", shape=(32, 3, 3), dtype=float32)
Tensor("ArgMax:0", shape=(32, 3), dtype=int64)
...
Because argmax is applied along dimension 1, the shape of Tensor("ArgMax:0") is (32, 3). Is there any way to get an argmax output of shape (32,) without reshaping the tensor before applying argmax?
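To make the shape bookkeeping in the question explicit, here is a small NumPy sketch (NumPy instead of TensorFlow so it runs without a session; the random input is made-up stand-in data with the question's shape and range). It shows why `argmax` along axis 1 of a (32, 3, 3) tensor yields (32, 3), and that merging the two grid axes first (the reshape the question hopes to avoid) is what produces one index per batch element.

```python
import numpy as np

# Stand-in input with the question's shape: batch of 32, a 3x3 grid, 5 channels.
rng = np.random.default_rng(0)
inp = rng.uniform(-541, 23, size=(32, 3, 3, 5)).astype(np.float32)

rel = np.maximum(inp, 0)    # equivalent of tf.nn.relu
redu = rel.sum(axis=3)      # equivalent of tf.reduce_sum(rel, 3) -> (32, 3, 3)
loc = redu.argmax(axis=1)   # equivalent of tf.argmax(redu, 1)   -> (32, 3): a row index for each column

# Flattening the 3x3 grid first gives a single flat index (0..8) per batch element.
flat = redu.reshape(redu.shape[0], -1).argmax(axis=1)   # -> (32,)

print(redu.shape, loc.shape, flat.shape)   # (32, 3, 3) (32, 3) (32,)
```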
You probably don't want an output of size (32,), because when you take the argmax along several dimensions you usually want the coordinates of the maximum in every reduced dimension. In your case, you would want an output of size (32, 2): a (row, column) pair for each of the 32 batch elements.
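As a concrete illustration of the (32, 2) output for the shapes in the question, the following NumPy sketch (random stand-in data, not from the original post) takes the flat argmax over each 3x3 grid and converts it back to (row, column) with `np.unravel_index`, which performs the same `index // width`, `index % width` arithmetic that the TensorFlow snippet below does by hand.

```python
import numpy as np

rng = np.random.default_rng(1)
redu = rng.standard_normal((32, 3, 3))   # stand-in for the (32, 3, 3) tensor in the question

flat = redu.reshape(32, -1).argmax(axis=1)            # (32,) flat index into each 3x3 grid
rows, cols = np.unravel_index(flat, redu.shape[1:])   # two (32,) index arrays
coords = np.stack([rows, cols], axis=1)               # (32, 2): (row, col) per batch element

# Every recovered coordinate points at the maximum of its own 3x3 grid.
assert np.all(redu[np.arange(32), rows, cols] == redu.max(axis=(1, 2)))
print(coords.shape)   # (32, 2)
```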
You can do a two-dimensional argmax like this:
import numpy as np
import tensorflow as tf
x = np.zeros((10,9,8))
# pick a random position for each batch image that we set to 1
pos = np.stack([np.random.randint(9,size=10), np.random.randint(8,size=10)])
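# prepend the batch index so that x[tuple(posext)] addresses x[i, row_i, col_i]
posext = np.concatenate([np.expand_dims(np.arange(10), axis=0), pos])
x[tuple(posext)] = 1
# flatten each 9x8 image and take the argmax of the flat index
a = tf.argmax(tf.reshape(x, [10, -1]), axis=1)
# convert the flat index back to (row, col): row = a // 8, col = a mod 8 (8 is the image width)
pos2 = tf.stack([a // 8, tf.mod(a, 8)])  # recovered positions, shape (2, 10): one column per batch image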
sess = tf.InteractiveSession()
# check that the recovered positions are as expected
assert (pos == pos2.eval()).all(), "it did not work"
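If avoiding the reshape matters, one alternative (not part of the original answer) is a two-step argmax: first find the row that contains each slice's maximum, then find the column within that row. The sketch below does this in NumPy on the same kind of test data as the answer; the helper name `argmax_2d` is hypothetical. In TensorFlow 2.x, `tf.gather(x, rows, batch_dims=1)` should play the role of the fancy indexing `x[np.arange(n), rows]`, but treat that as an assumption to check against your version. With ties, both this and the flattened argmax return the first maximum in row-major order.

```python
import numpy as np

def argmax_2d(x):
    """Hypothetical helper: (row, col) of the maximum of each (H, W) slice of an
    (N, H, W) array, found without reshaping."""
    n = x.shape[0]
    rows = x.max(axis=2).argmax(axis=1)          # (N,) row holding each slice's maximum
    cols = x[np.arange(n), rows].argmax(axis=1)  # (N,) column of the maximum within that row
    return rows, cols

# Same setup as the answer: one 1 per image at a random position in a 9x8 grid.
rng = np.random.default_rng(2)
x = np.zeros((10, 9, 8))
pos = np.stack([rng.integers(9, size=10), rng.integers(8, size=10)])
x[np.arange(10), pos[0], pos[1]] = 1

rows, cols = argmax_2d(x)
assert (np.stack([rows, cols]) == pos).all(), "it did not work"

# If a single (N,) flat index is really what is needed, it can be rebuilt from the pair.
flat = rows * x.shape[2] + cols
assert (flat == x.reshape(10, -1).argmax(axis=1)).all()
```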