Reputation: 119
What is the threshold value that is used by TF by default to classify an input image as being a certain class?
For example, say I have 3 classes 0, 1, 2, and the labels for images are one-hot encoded like so: [1, 0, 0], meaning this image has label of class 0.
Now when a model outputs a prediction after softmax like this one: [0.39, 0.56, 0.05], does TF use 0.5 as the threshold so the class it predicts is class 1?
What if all the predictions were below 0.5, like [0.33, 0.33, 0.33]? What would TF say the result is?
And is there any way to specify a new threshold for example 0.7 and ensure TF says that a prediction is wrong if no class prediction is above that threshold?
Also would this logic carry over to the inference stage too where if the network is uncertain of the class then it will refuse to give a classification for the image?
Upvotes: 6
Views: 9502
Reputation: 60319
when a model outputs a prediction after softmax like this one: [0.39, 0.56, 0.05] does TF use 0.5 as the threshold so the class it predicts is class 1?
No. There is no threshold involved here. TensorFlow (and any other framework, for that matter) will just pick the maximum (argmax); the result here (class 1) would be the same even if the probabilistic output were [0.33, 0.34, 0.33].
You seem to erroneously believe that a probability value of 0.5 has some special significance in a 3-class classification problem; it does not. A probability value of 0.5 is "special" only in a binary classification setting (and a balanced one, for that matter). In an n-class setting, the respective "special" value is 1/n (here 0.33), and by definition, there will always be some entry in the probability vector greater than or equal to this value.
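As a minimal sketch of the point above, assuming the softmax output is available as a NumPy array (the helper name `predict_class` is hypothetical and not part of TensorFlow):

```python
import numpy as np

def predict_class(probs):
    """Return the index of the largest probability; no threshold is involved."""
    return int(np.argmax(probs))

probs = np.array([0.39, 0.56, 0.05])
n = probs.size

print(predict_class(probs))                         # 1
print(predict_class(np.array([0.33, 0.34, 0.33])))  # 1
# The largest entry can never be below the mean of the entries, which is 1/n
# for a probability vector that sums to 1.
print(bool(probs.max() >= 1.0 / n))                 # True
```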
What if all the predictions were below 0.5 like [0.33, 0.33, 0.33] what would TF say the result is?
As already implied, there is nothing strange or unexpected with all probabilities being below 0.5 in an n-class problem with n>2.
Now, if all the probabilities happen to be equal, as in the example you show (highly improbable in practice, but the question is valid, at least in theory), such ties should ideally be resolved randomly (i.e. by picking a class at random). In practice, since this stage is usually handled by the argmax method of NumPy, the prediction will be the first class (i.e. class 0), which is not difficult to demonstrate:
import numpy as np
x = np.array([0.33, 0.33, 0.33])
np.argmax(x)
# 0
This is due to how NumPy handles such cases. From the argmax docs:
In case of multiple occurrences of the maximum values, the indices corresponding to the first occurrence are returned.
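If random tie-breaking is actually wanted, it has to be added by hand, since `np.argmax` always returns the first maximum. A minimal sketch, assuming a 1-D probability vector; the function name `argmax_random_ties` and the tolerance `atol` are illustrative choices, not part of NumPy or TensorFlow:

```python
import numpy as np

def argmax_random_ties(probs, rng=None, atol=1e-12):
    """Like argmax, but choose uniformly at random among tied maxima."""
    rng = np.random.default_rng() if rng is None else rng
    probs = np.asarray(probs, dtype=float)
    # Indices of all entries equal to the maximum (up to a small tolerance).
    candidates = np.flatnonzero(np.isclose(probs, probs.max(), rtol=0.0, atol=atol))
    return int(rng.choice(candidates))

rng = np.random.default_rng(0)
print(argmax_random_ties([0.1, 0.8, 0.1], rng))     # 1 (no tie, same as argmax)
print(argmax_random_ties([0.33, 0.33, 0.33], rng))  # 0, 1 or 2, chosen at random
```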
To your next question:
is there any way to specify a new threshold for example 0.7 and ensure TF says that a prediction is wrong if no class prediction is above that threshold?
Not in TensorFlow (or any other framework) itself, but this can always be done in a post-processing stage during inference: irrespective of what your classifier actually returns, you can add extra logic so that whenever the maximum probability is less than a threshold, your system (i.e. your model plus the post-processing logic) returns something like "I don't know / I am not sure / I can't answer". But again, this is external to TensorFlow (or any other framework used) and to the model itself, and it can be used only during inference, not during training (it would not make sense during training anyway, because training uses only the predicted class probabilities, not hard classes).
In fact, we implemented such a post-processing module in a toy project some years ago, an online service that classified dog breeds from images: when the maximum probability returned by the model was below a threshold (which was the case, say, when the model was shown an image of a cat instead of a dog), the system was programmed to respond with the question "Are you sure this is a dog?" instead of being forced to make a prediction among the predefined dog breeds.
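A minimal sketch of such a post-processing step, assuming the model's softmax output for one image is available as an array. The function name `predict_with_rejection`, the default threshold of 0.7, and the use of `None` for "not sure" are illustrative assumptions, not part of any framework:

```python
import numpy as np

def predict_with_rejection(probs, threshold=0.7):
    """Return the argmax class, or None when the model is not confident enough."""
    probs = np.asarray(probs, dtype=float)
    best = int(np.argmax(probs))
    if probs[best] < threshold:
        return None  # the caller can turn this into "I don't know"
    return best

print(predict_with_rejection([0.39, 0.56, 0.05]))  # None (0.56 < 0.7)
print(predict_with_rejection([0.05, 0.90, 0.05]))  # 1
```

This runs on the model's outputs at inference time only; the training loss is untouched.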
Upvotes: 8
Reputation: 808
A threshold is used in binary or multilabel classification. In multiclass classification you use argmax: the class with the highest activation is your output class. It is rare for all class probabilities to be equal; if the model is trained well, there should be one dominant class.
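The difference can be sketched as follows, assuming per-label sigmoid outputs for the multilabel case and softmax outputs for the multiclass case (the numbers and the 0.5 cut-off are made up for illustration):

```python
import numpy as np

# Multilabel: each label is an independent yes/no decision, so a per-label
# threshold is applied and several labels can be active at once.
sigmoid_outputs = np.array([0.80, 0.30, 0.65])
multilabel_pred = (sigmoid_outputs >= 0.5).astype(int)
print(multilabel_pred)   # [1 0 1]

# Multiclass: exactly one class is chosen, so argmax is used instead.
softmax_outputs = np.array([0.39, 0.56, 0.05])
multiclass_pred = int(np.argmax(softmax_outputs))
print(multiclass_pred)   # 1
```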
Upvotes: 3