astromax

Reputation: 6331

Python MeanShift Memory Error

I'm running a clustering algorithm called MeanShift() from the sklearn.cluster module (here are the docs). The dataset I'm dealing with has 310,057 points distributed in 3-dimensional space. The machine I'm running it on has a total of 128 GB of RAM, so when I get the following error, I have a hard time believing that I'm actually using all of it.

[user@host ~]$ python meanshifttest.py
Traceback (most recent call last):
  File "meanshifttest.py", line 13, in <module>
    ms = MeanShift().fit(X)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 280, in fit
    cluster_all=self.cluster_all)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 99, in mean_shift
bandwidth = estimate_bandwidth(X)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/cluster/mean_shift_.py", line 45, in estimate_bandwidth
d, _ = nbrs.kneighbors(X, return_distance=True)
  File "/home/user/anaconda/lib/python2.7/site-packages/sklearn/neighbors/base.py", line 313, in kneighbors
return_distance=return_distance)
  File "binary_tree.pxi", line 1313, in sklearn.neighbors.kd_tree.BinaryTree.query (sklearn/neighbors/kd_tree.c:10007)
  File "binary_tree.pxi", line 595, in sklearn.neighbors.kd_tree.NeighborsHeap.__init__ (sklearn/neighbors/kd_tree.c:4709)
MemoryError

The code I'm running looks like this:

from sklearn.cluster import MeanShift
import asciitable
import numpy as np
import time

# read the particle catalogue; columns 2-4 hold the x, y, z positions
data = asciitable.read('./multidark_MDR1_FOFID85000000000_ParticlePos.csv',delimiter=',')
x = [data[i][2] for i in range(len(data))]
y = [data[i][3] for i in range(len(data))]
z = [data[i][4] for i in range(len(data))]
# stack the three columns into an (n_samples, 3) array for scikit-learn
X = np.array(zip(x,y,z))

t0 = time.time()
ms = MeanShift().fit(X)
t1 = time.time()
print str(t1-t0) + " seconds."
labels = ms.labels_
print set(labels)

Would anybody have any ideas about what's happening? Unfortunately I can't switch clustering algorithms, because this is the only one I've found that does a good job while requiring no linking lengths, no number of clusters k, and no other a priori information.

Thanks in advance!

UPDATE: I looked into the documentation a little more, and it says the following:

Scalability:

Because this implementation uses a flat kernel and a Ball Tree to look up members of each kernel, the complexity will tend towards O(T*n*log(n)) in lower dimensions, with n the number of samples and T the number of points. In higher dimensions the complexity will tend towards O(T*n^2).

Scalability can be boosted by using fewer seeds, for example by using
a higher value of min_bin_freq in the get_bin_seeds function.

Note that the estimate_bandwidth function is much less scalable than
the mean shift algorithm and will be the bottleneck if it is used.

This seems to make some sense, because the traceback shows the failure happening inside estimate_bandwidth. Is this an indication that I'm simply using too many particles for the algorithm? Based on the scalability note, I'm wondering whether something like the sketch below would work around it.
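Here's a rough, untested sketch of what I have in mind: estimate the bandwidth on a random subsample and use bin seeding with a higher min_bin_freq, as the docs suggest (the subsample size of 10,000 and min_bin_freq=5 are just illustrative guesses):

from sklearn.cluster import MeanShift, estimate_bandwidth
import numpy as np

# estimate the bandwidth on a random subsample instead of all 310,057 points
idx = np.random.permutation(len(X))[:10000]
bandwidth = estimate_bandwidth(X[idx])

# bin seeding with a higher min_bin_freq reduces the number of seeds,
# as the scalability note suggests
ms = MeanShift(bandwidth=bandwidth, bin_seeding=True, min_bin_freq=5).fit(X)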

Upvotes: 3

Views: 1425

Answers (1)

Fred Foo

Reputation: 363487

Judging from the error message, I suspect it's trying to compute all pairwise distances between points, which would require 310057² floating-point numbers, or about 716 GB of RAM.
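A quick back-of-the-envelope check of that figure, assuming 8-byte doubles:

n = 310057
bytes_needed = n ** 2 * 8          # an n-by-n matrix of float64 distances
print(bytes_needed / 2.0 ** 30)    # roughly 716 GiB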

You can disable this behavior by giving an explicit bandwidth argument to the MeanShift constructor.
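For example (the bandwidth value here is made up; pick something appropriate for your data, e.g. by estimating it on a small subsample):

from sklearn.cluster import MeanShift

# an explicit bandwidth skips the call to estimate_bandwidth entirely
ms = MeanShift(bandwidth=2.0).fit(X)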

This is arguably a bug; consider filing a bug report for it. (The scikit-learn crew, myself included, has recently been working to get rid of these overly expensive distance computations in various places, but apparently no one has looked at mean shift yet.)

EDIT: the computation above was off by a factor of three, but the memory usage was indeed quadratic. I just fixed this in the dev version of scikit-learn.

Upvotes: 5
