ThisGuy

Reputation: 873

How does K-Means work?

I'm lost and confused about how K-Means works. What I know so far is:

  1. plot n points, say 7 points
  2. randomly choose k points, say 3 points, which will serve as centroids
  3. the centroids represent the classes, so we have 3 classes
  4. pick a point and determine its class
  5. the nearest of the 3 chosen classes classifies the picked point

I have already implemented reading a text file that contains the points. Once I select a file, the points are plotted. That's where I stopped.

Here's what I want to know:

1. I want to know what to do next after plotting the points, because I'm not sure about the algorithm I stated above.

2. I want to know how the iteration works, i.e., the iteration that produces the final class of each point. I'm confused because I don't understand how a picked point can change its class if it just takes the class of the nearest classified point.

Any help would be much appreciated.

Upvotes: 1

Views: 8031

Answers (1)

mattnedrich

Reputation: 8072

The input to K-Means is a set of points (observations), and an integer K. The goal is to partition the input points into K distinct sets (clusters).

The first step is to initialize the algorithm by choosing K initial cluster centroid locations. This is typically done by randomly choosing K points from the input set. With these initial K centroids, the algorithm proceeds by repeating the following two main steps:
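
For a concrete picture, here is a minimal sketch of the initialization step. The question doesn't say which language is being used, so Python is assumed here purely for illustration, and the names points and k are placeholders:

import random

def initialize_centroids(points, k):
    # points is assumed to be a list of coordinate tuples, e.g. [(x1, y1), (x2, y2), ...]
    # Pick k distinct input points at random to serve as the starting centroids.
    return random.sample(points, k)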

1) Cluster Assignment - Here, each observation (i.e., each point in the data set) is assigned to a cluster centroid such that the WCSS objective function is minimized. This usually amounts to assigning each observation to the closest cluster centroid (which minimizes WCSS when the distance is Euclidean), though for some distance metrics and spaces this need not be the case.
2) Update Centroids - After all of the input observations have been assigned to a cluster centroid, each centroid is re-computed. For each cluster, the new centroid is the mean (average) of the observations that were assigned to it. Both steps are sketched in the example below.
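
A minimal sketch of both steps, continuing the hypothetical Python example above (squared Euclidean distance is assumed as the distance function; all helper names are illustrative, not from the answer):

def squared_distance(a, b):
    # Squared Euclidean distance between two points of the same dimension.
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))

def assign_clusters(points, centroids):
    # Step 1: assign each point to the index of its closest centroid.
    assignments = []
    for p in points:
        distances = [squared_distance(p, c) for c in centroids]
        assignments.append(distances.index(min(distances)))
    return assignments

def update_centroids(points, assignments, old_centroids):
    # Step 2: recompute each centroid as the mean of the points assigned to it.
    new_centroids = []
    for cluster, old in enumerate(old_centroids):
        members = [p for p, a in zip(points, assignments) if a == cluster]
        if not members:
            # A cluster can end up empty; one simple fallback is to keep the old centroid.
            new_centroids.append(old)
        else:
            new_centroids.append(tuple(sum(coords) / len(members) for coords in zip(*members)))
    return new_centroids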

These steps are repeated until the algorithm "converges". There are several ways to detect convergence. The most typical is to run until none of the observations change cluster membership. As an additional tip, if you compute the WCSS (explained below) for each iteration, you should see it decrease (i.e., the error should become smaller as the algorithm runs). If not, your implementation probably has a bug.
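
Putting the pieces together, the iteration asked about in the question could look like the following sketch (built on the hypothetical helpers above; max_iterations is just a safety cap, not part of the algorithm described here):

def k_means(points, k, max_iterations=100):
    centroids = initialize_centroids(points, k)
    assignments = assign_clusters(points, centroids)
    for _ in range(max_iterations):
        centroids = update_centroids(points, assignments, centroids)
        new_assignments = assign_clusters(points, centroids)
        if new_assignments == assignments:
            # Converged: no observation changed cluster membership this iteration.
            break
        assignments = new_assignments
    return centroids, assignments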

It's also important to understand that K-Means is notorious for getting stuck in local minima. This means that the final result may not be the best result. To overcome this, K-Means is often run many times with different starting points (initial centroids), and the run with the lowest error is chosen.
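
A sketch of that restart strategy, again assuming the hypothetical helpers above and a wcss() error function like the one sketched further below:

def k_means_with_restarts(points, k, runs=10):
    best_result, best_error = None, float('inf')
    for _ in range(runs):
        centroids, assignments = k_means(points, k)
        error = wcss(points, centroids, assignments)
        if error < best_error:
            best_result, best_error = (centroids, assignments), error
    return best_result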

Within-cluster sum of squares (WCSS) is used to measure the error (it's explained on Wikipedia: http://en.wikipedia.org/wiki/K-means_clustering). WCSS is computed as

totalError = 0;
foreach (Point p in inputData)
{
    // p_centroid is the centroid of the cluster that p is currently assigned to
    distance = someDistanceFunc(p, p_centroid);
    pError = distance * distance;  // squared distance
    totalError += pError;
}

Essentially, for each point you compute an error measure based on how close it is to its centroid. All of these errors are added up to compute the total error.
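
In the same hypothetical Python sketch as above, the WCSS computation might look like:

def wcss(points, centroids, assignments):
    # Sum of squared distances from each point to the centroid of its assigned cluster.
    return sum(squared_distance(p, centroids[a]) for p, a in zip(points, assignments))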

There is a lot of K-Means information available on the internet. For a more in-depth description I recommend Andrew Ng's Coursera lectures: http://www.youtube.com/watch?v=Ao2vnhelKhI

Upvotes: 12
