Girish Kumar Chandora

Reputation: 188

How does Sklearn Naive Bayes Bernoulli Classifier work when the predictors are not binary?

As we know, the Bernoulli Naive Bayes classifier expects binary predictors (features). What I don't understand is how BernoulliNB in scikit-learn produces results even when the predictors are not binary. The following example is taken verbatim from the documentation:

import numpy as np
rng = np.random.RandomState(1)
X = rng.randint(5, size=(6, 100))
Y = np.array([1, 2, 3, 4, 4, 5])
from sklearn.naive_bayes import BernoulliNB
clf = BernoulliNB()
clf.fit(X, Y)

print(clf.predict(X[2:3]))

Output:

[3]

Here are the first 11 features of X, and they are obviously not binary:

3   4   0   1   3   0   0   1   4   4   1
1   0   2   4   4   0   4   1   4   1   0
2   4   4   0   3   3   0   3   1   0   2
2   2   3   1   4   0   0   3   2   4   1
0   4   0   3   2   4   3   2   4   2   4
3   3   3   3   0   2   3   1   3   2   3

How does BernoulliNB work here even though the predictors are not binary?

Upvotes: 1

Views: 2330

Answers (1)

desertnaut

Reputation: 60399

This is due to the binarize argument; from the docs:

binarize : float or None, default=0.0

Threshold for binarizing (mapping to booleans) of sample features. If None, input is presumed to already consist of binary vectors.

With the default binarize=0.0, which is what your code uses since you do not set it explicitly, every element of X greater than 0 is converted to 1 and every other element to 0. The transformed X that is actually fed to the BernoulliNB classifier therefore does consist of binary values.

The binarize argument works exactly the same way as the stand-alone preprocessing function of the same name; here is a simplified example, adapting your own:

from sklearn.preprocessing import binarize
import numpy as np

rng = np.random.RandomState(1)
X = rng.randint(5, size=(6, 1))
X
# result
array([[3],
       [4],
       [0],
       [1],
       [3],
       [0]])

binarize(X) # here as well, default threshold=0.0
# result (binary values):
array([[1],
       [1],
       [0],
       [1],
       [1],
       [0]])
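
If you want to check this end to end, here is a minimal sketch that reuses the data from your question. It assumes that fitting BernoulliNB with the default binarize=0.0 on the raw X should match fitting with binarize=None on an X you have binarized yourself. The names clf_default, clf_manual, X_bin, same_predictions, and same_params are just illustrative and not part of scikit-learn:

```python
import numpy as np
from sklearn.naive_bayes import BernoulliNB
from sklearn.preprocessing import binarize

# Same data as in the question
rng = np.random.RandomState(1)
X = rng.randint(5, size=(6, 100))
Y = np.array([1, 2, 3, 4, 4, 5])

# Default settings: BernoulliNB thresholds X at 0.0 internally
clf_default = BernoulliNB().fit(X, Y)

# Binarize manually, then tell BernoulliNB the input is already binary
X_bin = binarize(X, threshold=0.0)
clf_manual = BernoulliNB(binarize=None).fit(X_bin, Y)

# Both models should agree on predictions and on the learned parameters
same_predictions = np.array_equal(clf_default.predict(X), clf_manual.predict(X_bin))
same_params = np.allclose(clf_default.feature_log_prob_, clf_manual.feature_log_prob_)
print(same_predictions, same_params)
```

Note that the threshold discards the magnitudes: a 1 and a 4 are treated identically. If the counts themselves carry information, a different model (for example MultinomialNB) may be a better fit than relying on the default binarization.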

Upvotes: 2
