bkshi

Reputation: 300

TensorFlow: tf.argmax() as prediction, or the maximum value?

I am learning TensorFlow, and in various examples I have seen that to get predictions from logits we use tf.argmax(logits, 1). As I understand it, logits are the probability values, and tf.argmax() gives the index of the maximum value along the specified axis. But how can we use the indices in place of the probability values? Shouldn't we use the maximum value as the prediction?

But I have seen that the above code works fine. I'm sure I'm missing some basics here. Can anybody clear this up with an example?

Upvotes: 3

Views: 11894

Answers (1)

nessuno

Reputation: 27042

Usually, logits is the output tensor of a classification network. Its contents are unnormalized scores (often described as unnormalized log-probabilities): they are not scaled to lie between 0 and 1, and they do not sum to 1.

tf.argmax gives you the index of maximum value along the specified axis.
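To see the difference between the maximum value and its index, here is a minimal sketch in NumPy, whose `np.max` and `np.argmax` use the same axis convention as `tf.reduce_max` and `tf.argmax`. The batch of logits is made up for illustration.

```python
import numpy as np

# Hypothetical batch: 2 examples (rows), 3 classes (columns)
logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.3, 3.2]])

# Largest score in each row: tells you "how high", not "which class"
max_scores = np.max(logits, axis=1)

# Position of the largest score in each row: this is the predicted class
predicted = np.argmax(logits, axis=1)

print(max_scores)  # [2.  3.2]
print(predicted)   # [0 2]
```

The maximum value alone cannot tell you which class won; only its position can.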

You can convert logits to a pseudo-probability distribution (a tensor whose values are non-negative and sum to 1 along the class axis) with softmax, and feed that to argmax:

top = tf.argmax(tf.nn.softmax(logits), 1)
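For reference, a row-wise softmax can be sketched in NumPy as below. This is an illustration of the formula, not TensorFlow's implementation of `tf.nn.softmax`: exponentiate each score and divide by the row sum. Subtracting the row maximum first leaves the result unchanged but avoids overflow for large logits such as 500.

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax over the last axis (illustrative sketch)."""
    logits = np.asarray(logits, dtype=np.float64)
    # Shift by the row max for numerical stability; the output is unchanged
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / exps.sum(axis=-1, keepdims=True)

probs = softmax([[2.0, 1.0, 0.1]])
print(probs)              # every value lies strictly between 0 and 1
print(probs.sum(axis=1))  # sums to 1 (up to floating-point rounding)
```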

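but in the end the result is the same as feeding the logits to argmax directly, because softmax preserves the ordering of the values within each row (a larger logit always gets a larger probability):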

top = tf.argmax(logits, 1)
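A quick randomized check of that claim, sketched in NumPy (the random logits are arbitrary): since softmax never reorders the values within a row, argmax of the softmax output matches argmax of the raw logits.

```python
import numpy as np

def softmax(x):
    # Stable row-wise softmax
    shifted = x - x.max(axis=1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=1, keepdims=True)

rng = np.random.default_rng(0)
logits = rng.normal(size=(1000, 10)) * 5.0  # 1000 examples, 10 classes

same = np.array_equal(np.argmax(softmax(logits), axis=1),
                      np.argmax(logits, axis=1))
print(same)  # True
```

This is why softmax can be skipped when only the predicted class is needed.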

Either way, you need argmax to find out which class the network predicted for a given input. The probabilities (normalized or unnormalized) only tell you how strongly the network favors each class; the position of the largest one is what identifies the class.
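In practice you often want both pieces of information: argmax for the class, and the matching softmax probability as a confidence score. A NumPy sketch with made-up logits:

```python
import numpy as np

logits = np.array([[1.0, 3.0, 0.5],
                   [2.5, 0.0, 2.4]])

# Stable softmax over the class axis
e = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = e / e.sum(axis=1, keepdims=True)

pred_class = np.argmax(probs, axis=1)                  # index = class
confidence = probs[np.arange(len(probs)), pred_class]  # its probability

for c, p in zip(pred_class, confidence):
    print(f"class {c} with probability {p:.3f}")
```

The maximum value answers "how sure?", the index answers "which class?".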

Just think about a logits tensor like:

logits = [ [ 10, 500, -1, 0.5, 12 ] ]

The tensor shape is [1, 5]. Just by looking at the values, you can see that the class with the highest confidence is the one at index 1 (indices start at 0), with value 500.

How can you extract the position of the highest value? You have to use argmax:

top = tf.argmax(logits, 1)

Once executed, it returns a tensor holding the value 1 (shape [1], one index per row of logits).
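The same example reproduced in NumPy, used here as a stand-in for TensorFlow (`np.argmax(x, axis=1)` follows the same convention as `tf.argmax(x, 1)`):

```python
import numpy as np

logits = np.array([[10, 500, -1, 0.5, 12]])  # shape (1, 5)

top = np.argmax(logits, axis=1)
print(top.shape)          # (1,)   one index per row
print(top)                # [1]    class 1, whose score is 500
print(logits[0, top[0]])  # 500.0  the score itself
```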

Summary: the values in logits are scores, and their indices are classes. By using argmax(), you obtain the predicted class.
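Because the indices are classes, turning a prediction into a readable label is just a lookup. In this NumPy sketch, `class_names` is a hypothetical label list; its order must match the network's output columns.

```python
import numpy as np

# Hypothetical labels: position i must correspond to column i of the logits
class_names = ["cat", "dog", "bird", "fish", "horse"]

logits = np.array([[10, 500, -1, 0.5, 12],
                   [3, -2, 7, 0, 1]])

predicted = np.argmax(logits, axis=1)  # one class index per row
labels = [class_names[i] for i in predicted]
print(labels)  # ['dog', 'bird']
```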

Upvotes: 17
