VanThaoNguyen

Reputation: 812

An alternative to the Elbow method for finding the number of clusters k

I am trying to find an appropriate number of clusters k for the k-means method in machine learning. I used the Elbow method, but it is time-consuming and has high complexity. Can anyone suggest another method to replace it? Thanks so much.

Upvotes: 1

Views: 820

Answers (1)

stackoverflowuser2010

Reputation: 40869

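A metric that you can use to evaluate the result of a clustering is the silhouette coefficient. For each point, let a be the mean distance to the other points in its own cluster (cohesion) and b the mean distance to the points in the nearest other cluster (separation). The point's silhouette is (b - a) / max(a, b), and the coefficient is the average over all points. When a < b, which is the usual case for a reasonable clustering, this reduces to: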

silhouette coefficient = 1 - (intra-cluster cohesion) / (inter-cluster separation)

The value ranges from -1 to +1, and you generally want values closer to +1. So if you run a clustering algorithm (e.g. k-means or hierarchical clustering) to produce 3 clusters, you can call a silhouette library to compute a silhouette coefficient value, e.g. 0.50. If you run your algorithm again to produce 4 clusters, you may compute another silhouette coefficient value, e.g. 0.55. You could then conclude that 4 clusters is the better clustering because it has the higher silhouette coefficient.
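To make the definition concrete, here is a minimal sketch that computes the per-point silhouette by hand and compares it with silhouette() from the cluster package. The toy data, the seed, and the helper name manualSilhouette are my own assumptions for illustration. The sketch also assumes every cluster has at least two points (the package assigns 0 to a point that is alone in its cluster, which this helper does not handle).

```r
library(cluster)  # for silhouette()

set.seed(42)  # arbitrary seed, only so the toy data is repeatable

# Hypothetical toy data: two well-separated groups of 10 points in 2-D
pts <- rbind(matrix(rnorm(20, mean = 0, sd = 0.3), ncol = 2),
             matrix(rnorm(20, mean = 3, sd = 0.3), ncol = 2))
labels <- kmeans(pts, centers = 2, nstart = 5)$cluster
d <- as.matrix(dist(pts))  # full pairwise Euclidean distance matrix

# Silhouette of a single point i, straight from the definition
manualSilhouette <- function(i) {
    own <- labels == labels[i]
    own[i] <- FALSE                     # exclude the point itself
    a <- mean(d[i, own])                # cohesion: mean distance within own cluster
    otherClusters <- setdiff(unique(labels), labels[i])
    # separation: smallest mean distance to the points of any other cluster
    b <- min(sapply(otherClusters, function(cl) mean(d[i, labels == cl])))
    (b - a) / max(a, b)
}

manual <- sapply(seq_len(nrow(pts)), manualSilhouette)
packaged <- as.numeric(silhouette(labels, dist(pts))[, "sil_width"])

# Should print TRUE if the hand computation matches the package
print(isTRUE(all.equal(manual, packaged)))
```

Averaging the per-point values (mean(manual)) gives the single silhouette coefficient used to compare clusterings.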

Below is an example data set where I've created three different clusters in 2-D space using R. NOTE: Real-world data will never look this clean with such obvious separation between clusters. Even simple data like Fisher's Iris data set has overlap between labelled clusters.

[Figure: scatter plot of the three sample clusters]

You can then use the silhouette function from R's cluster package (load it with library(cluster)) to compute the silhouette coefficient. (More information can be found at the STHDA website.) Below are plots of the silhouette information. The one metric that you want is in the lower-left corner, which says "Average silhouette width: xxx". That value is the average of all the horizontal bars.

Here is the silhouette coefficient for K=2 clusters.

plot(silhouette(kmeans(df, centers=2)$cluster, dist(df)))

[Figure: silhouette plot for K=2]

Here is the silhouette coefficient for K=3 clusters.

plot(silhouette(kmeans(df, centers=3)$cluster, dist(df)))

[Figure: silhouette plot for K=3]

Here is the silhouette coefficient for K=4 clusters.

plot(silhouette(kmeans(df, centers=4)$cluster, dist(df)))

[Figure: silhouette plot for K=4]

From looking at the silhouette coefficients, you can conclude that K=3 clusters is the best clustering because it has the highest silhouette coefficient.
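If you would rather not read the average off the plot, the same number can be pulled out of the silhouette object. This is a small sketch on data generated the same way as the sample set. The seed and the variable names sil and silSummary are assumptions of mine, so the printed values will not match the figures exactly.

```r
library(cluster)  # for silhouette()

set.seed(1)  # arbitrary seed, only for repeatability

# Stand-in for df: three 2-D clusters around (1,1), (2,4), and (3,1)
x <- c(rnorm(25, mean = 1, sd = 0.1), rnorm(25, mean = 2, sd = 0.1), rnorm(25, mean = 3, sd = 0.2))
y <- c(rnorm(25, mean = 1, sd = 0.1), rnorm(25, mean = 4, sd = 0.1), rnorm(25, mean = 1, sd = 0.2))
df <- data.frame(x = x, y = y)

sil <- silhouette(kmeans(df, centers = 3, nstart = 5)$cluster, dist(df))
silSummary <- summary(sil)

print(silSummary$avg.width)        # the "Average silhouette width" shown on the plot
print(silSummary$clus.avg.widths)  # one average per cluster
```

The overall average is the same value as mean(sil[, "sil_width"]), which is what the sweep further down uses.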

You can find the best K value programmatically by sweeping through the candidate K values (e.g. between 2 and 10) while keeping track of the highest silhouette coefficient found. Below I've done just that for K from 2 to 5, while also building a plot of silhouette coefficient (y-axis) vs. K (x-axis). The sample data is random, so the exact value changes from run to run. For one run, the output says:

Best Silhouette coefficient=0.888926 occurred at k=3

[Figure: average silhouette width vs. K]

library(cluster) # for silhouette
library(ggplot2) # for ggplot
library(scales) # for pretty_breaks


# Create sample 2-D data set with clusters around the points (1,1), (2,4), and (3,1)
x<- c(rnorm(n=25, mean=1,sd=.1), rnorm(n=25,mean=2,sd=.1),rnorm(n=25,mean=3,sd=.2))
y<- c(rnorm(n=25, mean=1,sd=.1), rnorm(n=25,mean=4,sd=.1),rnorm(n=25,mean=1,sd=.2))

df <- data.frame(x=x, y=y)

xMax <- max(x)
yMax <- max(y)
print(ggplot(df, aes(x,y)) + geom_point() + xlim(0, max(xMax, yMax)) + ylim(0, max(xMax,yMax)))


# Use the Iris data set.
#df <- subset(iris, select=-c(Species))
#df <- scale(df)


# Run through multiple candidate values of K clusters.

xValues <- c() # Holds the K values (x-axis)
yValues <- c() # Holds the silhouette coefficient values (y-axis)
bestKValue <- 0
bestSilhouetteCoefficient <- -Inf # Silhouette values can be negative, so do not start at 0

kSequence <- seq(2, 5)

for (kValue in kSequence) {

    xValues <- append(xValues, kValue)
    kmeansResult <- kmeans(df, centers=kValue, nstart=5)
    silhouetteResult <- silhouette(kmeansResult$cluster, dist(df))
    silhouetteCoefficient <- mean(silhouetteResult[,3])
    yValues <- append(yValues, silhouetteCoefficient)

    if (silhouetteCoefficient > bestSilhouetteCoefficient) {
        bestSilhouetteCoefficient <- silhouetteCoefficient
        bestKValue <- kValue
    }
}

# Create a dataframe for ggplot to plot the accumulated silhouette values.
dfSilhouette <- data.frame(k=xValues, silhouetteCoefficient=yValues)

# Create the ggplot line plot for silhouette coefficient.
silhouettePlot<- ggplot(data=dfSilhouette, aes(k)) +
    geom_line(aes(y=silhouetteCoefficient)) +
    xlab("k") +
    ylab("Average silhouette width") +
    ggtitle("Average silhouette width") +
    scale_x_continuous(breaks=pretty_breaks(n=20))

print(silhouettePlot)

printf <- function(...) cat(sprintf(...))
printf("Best Silhouette coefficient=%f occurred at k=%d\n", bestSilhouetteCoefficient, bestKValue)

Note that I've used the printf function from the answer here.
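Since run time was part of the question, one possible tweak is to compute the distance matrix once rather than inside the loop, because dist(df) does not depend on K. The following is only a sketch: the helper name sweepSilhouette, its parameters, and the seed are hypothetical and not part of the code above.

```r
library(cluster)  # for silhouette()

# Hypothetical helper: average silhouette width for each candidate K
sweepSilhouette <- function(data, kCandidates, nstart = 5) {
    distances <- dist(data)  # computed once and reused for every K
    widths <- vapply(kCandidates, function(k) {
        clusters <- kmeans(data, centers = k, nstart = nstart)$cluster
        mean(silhouette(clusters, distances)[, "sil_width"])
    }, numeric(1))
    data.frame(k = kCandidates, avgWidth = widths)
}

set.seed(1)  # arbitrary seed, only for repeatability

# Same kind of sample data as above: clusters around (1,1), (2,4), and (3,1)
x <- c(rnorm(25, mean = 1, sd = 0.1), rnorm(25, mean = 2, sd = 0.1), rnorm(25, mean = 3, sd = 0.2))
y <- c(rnorm(25, mean = 1, sd = 0.1), rnorm(25, mean = 4, sd = 0.1), rnorm(25, mean = 1, sd = 0.2))
df <- data.frame(x = x, y = y)

result <- sweepSilhouette(df, 2:5)
bestK <- result$k[which.max(result$avgWidth)]  # K with the highest average width
print(result)
print(bestK)
```

The distance matrix still grows with the square of the number of rows, so for very large data sets this saves repeated work but does not remove the main cost.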

A related question to yours is here.

Upvotes: 4
