Gilfoyle

Reputation: 3636

Simple k-means algorithm in Python

The following is a very simple implementation of the k-means algorithm.

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)

DIM = 2
N = 2000
num_cluster = 4
iterations = 3

x = np.random.randn(N, DIM)                # data: N standard-normal points
y = np.random.randint(0, num_cluster, N)   # random initial cluster assignments

mean = np.zeros((num_cluster, DIM))
for t in range(iterations):
    # update step: recompute each cluster's centroid
    for k in range(num_cluster):
        mean[k] = np.mean(x[y==k], axis=0)
    # assignment step: move each point to its nearest centroid
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        pred = np.argmin(dist)
        y[i] = pred

for k in range(num_cluster):
    plt.scatter(x[y==k,0], x[y==k,1])
plt.show()

Here are two example outputs the code produces:

[Plot 1: scatter plot of the clustering result with num_cluster = 4]

[Plot 2: scatter plot of the clustering result with num_cluster = 11]

The first example (num_cluster = 4) looks as expected. The second example (num_cluster = 11), however, shows only one cluster, which is clearly not what I wanted. Whether the code behaves correctly seems to depend on the number of clusters I define and on the number of iterations.

So far I haven't been able to find the bug in the code. Somehow the clusters disappear, but I don't know why.

Does anyone see my mistake?

Upvotes: 1

Views: 2180

Answers (3)

ernest

Reputation: 161

K-means is pretty sensitive to initial conditions. It can and will drop clusters (though collapsing all the way down to one is unusual). In your code, you assign random clusters to the points before the first iteration.

Here's the problem: each of those random subsamples of your data has about the same mean, so all the initial centroids land very close to one another. With near-identical centroids, every iteration is more likely to leave some of them empty and drop them.
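
A quick way to see this (a minimal sketch using the question's data setup, with 11 clusters as in the second example):

import numpy as np

np.random.seed(0)
x = np.random.randn(2000, 2)           # same data shape as the question
y = np.random.randint(0, 11, 2000)     # random initial assignments, 11 clusters

# each centroid is the mean of a large random subsample, so all of them
# land very close to the global mean of the data (about [0, 0])
centroids = np.array([x[y == k].mean(axis=0) for k in range(11)])
print(np.abs(centroids).max())         # small, e.g. on the order of 0.1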

Instead, I changed your code to pick num_cluster points from your data set to use as the initial centroids (much higher variance between them). This seems to produce more stable results; I didn't observe the collapse to one cluster over several dozen runs:

import numpy as np
import matplotlib.pyplot as plt

DIM = 2
N = 2000
num_cluster = 11
iterations = 3

x = np.random.randn(N, DIM)
y = np.zeros(N, dtype=int)
# initialize clusters by picking num_cluster random points as centroids;
# could improve on this by deliberately choosing maximally different points
for t in range(iterations):
    if t == 0:
        index_ = np.random.choice(N, num_cluster, replace=False)
        mean = x[index_]
    else:
        for k in range(num_cluster):
            mean[k] = np.mean(x[y == k], axis=0)
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        y[i] = np.argmin(dist)

for k in range(num_cluster):
    plt.scatter(x[y == k, 0], x[y == k, 1])
plt.show()

I ran this a few dozen times with consistent results.

Upvotes: 2

brezniczky

Reputation: 542

It does seem that NaNs are entering the picture. Using seed=1 and iterations=2, the number of clusters reduces from the initial 4 to effectively 3, and in the next iteration it technically plummets to 1.
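
The NaNs come from taking the mean of an empty selection; a tiny check makes this visible (a sketch, not part of the original code):

import numpy as np

x = np.random.randn(10, 2)
y = np.zeros(10, dtype=int)        # every point assigned to cluster 0
print(np.mean(x[y == 3], axis=0))  # cluster 3 is empty -> [nan nan], plus a RuntimeWarning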

The NaN mean coordinates of the emptied centroid then cause the weirdness: np.argmin returns the index of a NaN when one is present, so every point gets assigned to the empty cluster. To rule out clusters that became empty, one (possibly a bit too lazy) option is to set their coordinates to Inf, making them "more distant than any other" point still in the game (as long as the input coordinates cannot be Inf). The snippet below is a quick illustration of that, with a few debug messages I used to peek into what was going on:

[...]
for k in range(num_cluster):
    mean[k] = np.mean(x[y==k], axis=0)
    # print(mean[k])
    if np.any(np.isnan(mean[k])):
        # print("oh no!")
        mean[k] = [np.inf] * DIM
[...]

With this modification, the posted algorithm seems to work in a more stable fashion (that is, I haven't been able to break it so far).
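
For reference, a self-contained version of the question's loop with this guard could look like the following (a sketch; setup and names follow the question, and the len() check avoids computing the NaN mean in the first place):

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(1)

DIM = 2
N = 2000
num_cluster = 4
iterations = 3

x = np.random.randn(N, DIM)
y = np.random.randint(0, num_cluster, N)

mean = np.zeros((num_cluster, DIM))
for t in range(iterations):
    for k in range(num_cluster):
        members = x[y == k]
        if len(members) == 0:
            # empty cluster: push its centroid to infinity so it
            # can never be the nearest centroid again
            mean[k] = np.inf
        else:
            mean[k] = members.mean(axis=0)
    for i in range(N):
        dist = np.sum((mean - x[i])**2, axis=1)
        y[i] = np.argmin(dist)

for k in range(num_cluster):
    plt.scatter(x[y == k, 0], x[y == k, 1])
plt.show()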

Please also see the Quora link mentioned among the comments about the split opinions on this, and the book "The Elements of Statistical Learning", for example here; the algorithm is not defined too explicitly there either in the relevant respect.

Upvotes: 1

ShlomiF

Reputation: 2905

You're getting one cluster because there really is only one cluster left.
There's nothing in your code to prevent clusters from disappearing, and the truth is that this will happen with 4 clusters too, just after more iterations.
I ran your code with 4 clusters and 1000 iterations, and they all got swallowed up by the one big, dominant cluster.
Think about it: once your large cluster passes a critical point, it just keeps growing, because other points gradually become closer to it than to their previous mean.
This will not happen if you reach an equilibrium (or stationary) point, in which nothing moves between clusters. But such a point is obviously a bit rare, and rarer the more clusters you're trying to estimate.
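
To watch this happen, you can instrument the question's loop to print how many clusters still have members after each iteration (a minimal diagnostic sketch; once a cluster empties, its mean becomes NaN, the NaN wins np.argmin, and everything collapses into that cluster on the next pass):

import numpy as np

np.random.seed(0)
N, num_cluster = 2000, 4
x = np.random.randn(N, 2)
y = np.random.randint(0, num_cluster, N)

mean = np.zeros((num_cluster, 2))
for t in range(1000):
    for k in range(num_cluster):
        mean[k] = np.mean(x[y == k], axis=0)   # NaN once cluster k is empty
    for i in range(N):
        y[i] = np.argmin(np.sum((mean - x[i])**2, axis=1))
    survivors = len(np.unique(y))
    print(t, "clusters with members:", survivors)
    if survivors == 1:
        break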


A clarification: the same thing can also happen when there are 4 "real" clusters and you're trying to estimate 4 clusters. But that would require a rather nasty initialization, and it can be avoided by intelligently aggregating multiple randomly seeded runs.
There are also common "tricks", like taking the initial means to be far apart, or placing them at the centers of different pre-estimated high-density regions, etc. But that starts to get involved, and you should read more deeply about k-means for that purpose.
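
For illustration, a farthest-point initialization in that "far apart" spirit can be sketched in a few lines (a minimal illustration, related to but simpler than k-means++, which instead samples each new center randomly with probability proportional to squared distance):

import numpy as np

def farthest_point_init(x, k):
    # start from one random data point, then repeatedly add the point
    # farthest (in squared distance) from all centroids chosen so far
    centroids = [x[np.random.randint(len(x))]]
    for _ in range(k - 1):
        dist = np.min([np.sum((x - c)**2, axis=1) for c in centroids], axis=0)
        centroids.append(x[np.argmax(dist)])
    return np.array(centroids)

x = np.random.randn(2000, 2)
mean = farthest_point_init(x, 11)  # plug in as the initial `mean` in the loop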

Upvotes: 2
