Reputation: 17478
In the docs, predict_proba(self, x, batch_size=32, verbose=1) is described as "Generates class probability predictions for the input samples batch by batch." and returns "A Numpy array of probability predictions."
Suppose my model is a binary classification model. Is the output [a, b], where a is the probability of class_0 and b is the probability of class_1?
Upvotes: 4
Views: 11978
Reputation: 40516
Here the situation is different and somewhat misleading, especially if you are comparing the predict_proba method to the sklearn methods of the same name. In Keras itself (not in the sklearn wrappers), the predict_proba method is exactly the same as the predict method. You can even check it here:
def predict_proba(self, x, batch_size=32, verbose=1):
    """Generates class probability predictions for the input samples
    batch by batch.

    # Arguments
        x: input data, as a Numpy array or list of Numpy arrays
            (if the model has multiple inputs).
        batch_size: integer.
        verbose: verbosity mode, 0 or 1.

    # Returns
        A Numpy array of probability predictions.
    """
    preds = self.predict(x, batch_size, verbose)
    if preds.min() < 0. or preds.max() > 1.:
        warnings.warn('Network returning invalid probability values. '
                      'The last layer might not normalize predictions '
                      'into probabilities '
                      '(like softmax or sigmoid would).')
    return preds
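You can also see this equivalence directly. A minimal sketch, assuming an older standalone-Keras Sequential model where predict_proba is still available; the layer sizes and data below are made up for illustration:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# A toy binary classifier with a single sigmoid output unit.
model = Sequential()
model.add(Dense(1, input_dim=20, activation='sigmoid'))
model.compile(optimizer='adam', loss='binary_crossentropy')

x = np.random.rand(8, 20)
# predict_proba simply wraps predict, so both calls return the same array.
p1 = model.predict(x)
p2 = model.predict_proba(x)
assert np.allclose(p1, p2)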
So, in a binary classification case, the output you get depends on the design of your network:

- if the final layer is a single output unit with a sigmoid activation, then the output of predict_proba is simply the probability assigned to class 1;
- if the final layer is a two-unit layer with a softmax activation, then the output of predict_proba is a pair [a, b] where a = P(class(x) = 0) and b = P(class(x) = 1).

This second setup is rarely used, and there are some theoretical advantages to the first one, but I wanted to mention it just in case.
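For illustration, here is a small sketch of both designs (the layer sizes, data, and optimizer are assumptions, not taken from the question):

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

x = np.random.rand(4, 10)

# Design 1: single sigmoid unit. predict_proba returns shape (4, 1),
# each value being P(class = 1); P(class = 0) is one minus that value.
m1 = Sequential()
m1.add(Dense(1, input_dim=10, activation='sigmoid'))
m1.compile(optimizer='adam', loss='binary_crossentropy')
print(m1.predict_proba(x).shape)   # (4, 1)

# Design 2: two-unit softmax layer. predict_proba returns shape (4, 2),
# each row being [P(class = 0), P(class = 1)] and summing to 1.
m2 = Sequential()
m2.add(Dense(2, input_dim=10, activation='softmax'))
m2.compile(optimizer='adam', loss='categorical_crossentropy')
print(m2.predict_proba(x).shape)   # (4, 2)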
Upvotes: 13
Reputation: 4159
It depends on how you specify the output of your model and your targets; it can be either. Usually, when doing binary classification, the output is a single value giving the probability of the positive class, and one minus that output is the probability of the negative class.
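A minimal illustration of that convention (the numbers are made up):

import numpy as np

p_positive = np.array([[0.9], [0.2]])       # what predict / predict_proba returns
p_negative = 1.0 - p_positive               # probability of the negative class
print(np.hstack([p_negative, p_positive]))  # columns: [P(class 0), P(class 1)]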
Upvotes: 0