Frank

Reputation: 671

partial_fit with scikit-learn returns ValueError: The sum of the priors should be 1

I am trying to run a sklearn.naive_bayes.GaussianNB model with partial_fit. For this I calculate the priors like this:

import numpy as np
from sklearn.naive_bayes import GaussianNB

unique_lbls, counts = np.unique(labels, return_counts=True)
counts = counts.astype(float)
priors = counts / counts.sum()
model = GaussianNB(priors=priors)
model.partial_fit(X, y, classes=unique_lbls)

I get a `ValueError: The sum of the priors should be 1`, but I have checked, and the priors do sum to 1.0:

print priors.sum()
> 1.0

I am using the following versions:

Python 2.7.12
scikit-learn 0.18.2
numpy 1.13.1

I can only imagine that it comes down to sensitivity of the summed value, but I have tried to normalize the priors again with priors /= priors.sum() and it returns the same error.

Is there a different way to make sure that the priors sum to 1.0 with a higher tolerance, or is there some (to me not-)obvious reason this doesn't work?

Edit: labels is a numpy array containing the whole data set's labels represented as integers; X and y are a batch of the full data set. y and labels both have at least 100 examples from each class.

Upvotes: 0

Views: 217

Answers (1)

MB-F

Reputation: 23647

My first intuition was that something is wrong with the data. However, it looks like the partial_fit function does not even look at the data before raising that error. In particular, the implementation looks like this:

# Check that the sum is 1
if priors.sum() != 1.0:
    raise ValueError('The sum of the priors should be 1.')

They compare the sum of the priors exactly to 1.0, which is not numerically robust. With an unlucky combination of values, floating-point rounding can leave the normalized priors summing to something slightly different from 1.0. Consider this:

import numpy as np

priors = np.array([1, 2, 3, 4, 5, 6], dtype=float)
priors /= priors.sum()
print(priors.sum() == 1.0)  # False

Such a situation will make the check fail. Let's try to fix this:

priors[0] = 1.0 - priors[1:].sum()
print(priors.sum() == 1.0)  # True
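
The same idea can be wrapped in a small helper. The following is only a sketch: the names priors_sum_to_one and nudge_to_unit_sum are hypothetical and not part of scikit-learn or numpy. The first function shows the kind of tolerance-based comparison (np.isclose) that would be more robust than the exact check quoted above; the second folds the rounding residual into the largest prior and retries a few times. Rounding can still get in the way, so an exactly unit sum is likely but not guaranteed, and it is worth checking the result before passing it to GaussianNB.

```python
import numpy as np


def priors_sum_to_one(priors):
    """Tolerance-based alternative to the exact `priors.sum() != 1.0` test."""
    return bool(np.isclose(np.sum(priors), 1.0))


def nudge_to_unit_sum(priors, max_iter=10):
    """Normalize, then fold any rounding residual into the largest entry.

    Hypothetical helper: an exactly unit sum is likely but not guaranteed.
    """
    p = np.array(priors, dtype=float)  # np.array copies its input by default
    p /= p.sum()
    largest = np.argmax(p)
    for _ in range(max_iter):
        residual = 1.0 - p.sum()
        if residual == 0.0:
            break  # the exact comparison would now pass
        p[largest] += residual
    return p


priors = nudge_to_unit_sum([1, 2, 3, 4, 5, 6])
print(priors_sum_to_one(priors))  # tolerance-based check
print(priors.sum() == 1.0)        # exact check, as done by the code quoted above
```

If the exact check still prints False for your counts, adjusting a different entry (as in the priors[0] fix above) may succeed where this one does not.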

Upvotes: 1
