mlworker

Reputation: 291

PDF estimation in Scikit-Learn KDE

I am trying to compute a PDF estimate from a KDE computed with the scikit-learn module. I have seen two variants of scoring and have tried both: statements A and B below.

Statement A results in the following error:

AttributeError: 'KernelDensity' object has no attribute 'tree_'

Statement B results in the following error:

ValueError: query data dimension must match training data dimension

It seems like a silly error, but I cannot figure it out. Please help. The code is below:

from sklearn.neighbors import KernelDensity
import numpy

# d is my 1-D array data
xgrid = numpy.linspace(d.min(), d.max(), 1000)

density = KernelDensity(kernel='gaussian', bandwidth=0.08804).fit(d)

# statement A
density_score = KernelDensity(kernel='gaussian', bandwidth=0.08804).score_samples(xgrid)

# statement B
density_score = density.score_samples(xgrid)

density_score = numpy.exp(density_score)

If it helps, I am using version 0.15.2 of scikit-learn. I've tried this successfully with scipy.stats.gaussian_kde, so there is no problem with the data.
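Because scipy.stats.gaussian_kde is used here as the reference, note that the two libraries define bandwidth differently: scikit-learn's `bandwidth` is the absolute standard deviation of the Gaussian kernel, while a scalar `bw_method` in scipy is a factor that multiplies the sample standard deviation (computed with ddof=1). The sketch below, which assumes made-up normal data in place of `d`, converts one to the other and compares the two density curves. It also uses the 2-D `(n_samples, 1)` input shape for scikit-learn, which is what the answers below point out.

```python
import numpy as np
from scipy.stats import gaussian_kde
from sklearn.neighbors import KernelDensity

# Hypothetical stand-in for the question's 1-D data `d`
d = np.random.default_rng(7).normal(size=3000)
xgrid = np.linspace(d.min(), d.max(), 1000)
h = 0.08804  # absolute kernel standard deviation, as scikit-learn uses it

# scipy scales its kernel by the sample std (ddof=1), so convert h into a factor
scipy_kde = gaussian_kde(d, bw_method=h / d.std(ddof=1))
pdf_scipy = scipy_kde(xgrid)

# scikit-learn wants (n_samples, n_features) arrays for both fit and scoring
sk_kde = KernelDensity(kernel='gaussian', bandwidth=h).fit(d.reshape(-1, 1))
pdf_sklearn = np.exp(sk_kde.score_samples(xgrid.reshape(-1, 1)))

# Largest pointwise difference between the two estimates
print(np.max(np.abs(pdf_scipy - pdf_sklearn)))
```

With the bandwidths matched this way, the two curves should agree up to floating-point error.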

Upvotes: 6

Views: 4276

Answers (2)

Vahid Mirjalili

Reputation: 6501

With statement B, I had the same issue with this error:

 ValueError: query data dimension must match training data dimension

The issue here is that you have 1-D array data, but when you pass it to the fit() function, it assumes that you have only one data point with many dimensions! For example, if your training data has 100000 points, d should be treated as 100000x1, but fit takes it as 1x100000. The 1000-point xgrid is then read as a single point with 1000 dimensions, which does not match the 100000 training dimensions, hence the error.
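To make the shape problem concrete, here is a small NumPy-only sketch. It assumes that older scikit-learn releases promoted 1-D input to 2-D in the same way as `np.atleast_2d`, which turns a flat array into a single row. The data is a hypothetical stand-in for `d`.

```python
import numpy as np

# Hypothetical stand-in for the question's 1-D data `d`
d = np.random.default_rng(0).normal(size=100000)

# Promoting a 1-D array the atleast_2d way: one row (sample), 100000 columns (features)
as_one_sample = np.atleast_2d(d)
print(as_one_sample.shape)          # (1, 100000)

# What KernelDensity actually needs: 100000 samples with 1 feature each
as_column = d.reshape(-1, 1)
print(as_column.shape)              # (100000, 1)

# The grid promoted the same way has 1000 "features", which cannot match 100000
xgrid = np.linspace(d.min(), d.max(), 1000)
print(np.atleast_2d(xgrid).shape)   # (1, 1000)
```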

So you should reshape d before fitting, with d.reshape(-1,1), and do the same for xgrid with xgrid.reshape(-1,1):

density = KernelDensity(kernel='gaussian', bandwidth=0.08804).fit(d.reshape(-1,1))
density_score = density.score_samples(xgrid.reshape(-1,1))
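Putting the fix together, a complete runnable sketch might look like the following. The normal sample data is an assumption standing in for `d`, and the final trapezoid-rule sum is only a sanity check that `np.exp` of the log-density scores behaves like a PDF. It should come out close to 1, slightly less because a little kernel mass falls outside the grid range.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

# Hypothetical sample data standing in for the question's `d`
rng = np.random.default_rng(42)
d = rng.normal(loc=0.0, scale=1.0, size=2000)

xgrid = np.linspace(d.min(), d.max(), 1000)

# Fit on an (n_samples, 1) column rather than a flat 1-D array
density = KernelDensity(kernel='gaussian', bandwidth=0.08804).fit(d.reshape(-1, 1))

# score_samples returns the log-density, one value per grid point
log_dens = density.score_samples(xgrid.reshape(-1, 1))
pdf = np.exp(log_dens)
print(pdf.shape)   # (1000,)

# Trapezoid-rule integral of the estimate over the data range
area = np.sum((pdf[1:] + pdf[:-1]) / 2 * np.diff(xgrid))
print(area)
```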

Note: the issue with statement A is that you are calling score_samples on a new object that has not been fit yet. The tree_ attribute that the error mentions is only created by fit().
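A quick way to see the difference between statement A and a fitted estimator is to check for the `tree_` attribute before and after `fit()`. This is a sketch with hypothetical data. Newer scikit-learn releases raise `NotFittedError` here instead of a bare `AttributeError`, but `NotFittedError` subclasses `AttributeError`, so the `except` clause below catches either.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

# A freshly constructed estimator: only the hyperparameters are set
unfitted = KernelDensity(kernel='gaussian', bandwidth=0.08804)
print(hasattr(unfitted, 'tree_'))   # False

# Scoring before fitting fails, as in statement A
raised = False
try:
    unfitted.score_samples(np.zeros((3, 1)))
except AttributeError as exc:  # NotFittedError in newer releases is a subclass
    raised = True
    print(type(exc).__name__)

# fit() builds the internal tree over the training points
d = np.random.default_rng(0).normal(size=500)  # hypothetical data
fitted = KernelDensity(kernel='gaussian', bandwidth=0.08804).fit(d.reshape(-1, 1))
print(hasattr(fitted, 'tree_'))     # True
```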

Upvotes: 10

user1793558

Reputation: 31

You need to call the fit() function before you can evaluate (score_samples) or sample from the distribution.
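Both uses need a fitted estimator. The sketch below, with hypothetical data, fits once and then draws new points with `KernelDensity.sample` (supported for the gaussian and tophat kernels) and evaluates the density at them with `score_samples`.

```python
import numpy as np
from sklearn.neighbors import KernelDensity

d = np.random.default_rng(1).normal(size=1000)  # hypothetical data

# Fit first; both sampling and scoring depend on the fitted model
kde = KernelDensity(kernel='gaussian', bandwidth=0.08804).fit(d.reshape(-1, 1))

# Draw new points from the fitted density
draws = kde.sample(n_samples=5, random_state=0)
print(draws.shape)   # (5, 1)

# Evaluate the density at those points; exp turns log-density into density
dens = np.exp(kde.score_samples(draws))
print(dens)
```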

Upvotes: 1
