Frederic Bastiat

Reputation: 693

Efficient way of accessing data grouped by KMeans clusters

I am trying to draw a circle around each centroid, with a radius extending to the farthest point belonging to that cluster. Right now each circle's radius extends to the point in the entire training data set that is farthest from the cluster center.

Here is my code:

def KMeansModel(n):
    pca = PCA(n_components=2)
    reduced_train_data = pca.fit_transform(train_data)
    KM = KMeans(n_clusters=n)
    KM.fit(reduced_train_data)
    plt.plot(reduced_train_data[:, 0], reduced_train_data[:, 1], 'k.', markersize=2)
    centroids = KM.cluster_centers_
    # Plot the centroids as a red X
    plt.scatter(centroids[:, 0], centroids[:, 1],
                marker='x', color='r')
    for i in centroids:
        print(np.max(metrics.pairwise_distances(i, reduced_train_data)))
        plt.gca().add_artist(plt.Circle(i, np.max(metrics.pairwise_distances(i, reduced_train_data)), fill=False))
    plt.show()

out = [KMeansModel(n) for n in np.arange(1,16,1)]

Upvotes: 1

Views: 1804

Answers (1)

Miriam Farber

Reputation: 19664

When you do

metrics.pairwise_distances(i, reduced_train_data)

you calculate the distances from the centroid to all the training points, not only to the training points assigned to the relevant cluster. To find the positions of the training points that belong to cluster ind, you can do

np.where(KM.labels_==ind)[0]
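As a minimal sketch of what that indexing returns, here is a toy example. The arrays labels and points are hypothetical stand-ins for KM.labels_ and reduced_train_data, not data from the question:

```python
import numpy as np

# Hypothetical stand-ins for KM.labels_ and reduced_train_data
labels = np.array([0, 1, 0, 2, 1])
points = np.array([[0.0, 0.0], [5.0, 5.0], [1.0, 0.0], [9.0, 1.0], [5.0, 6.0]])

ind = 0
# Row positions of the points assigned to cluster `ind`
class_inds = np.where(labels == ind)[0]
# Selecting those rows gives only the members of that cluster
cluster_points = points[class_inds]
# A boolean mask selects the same rows without the np.where step
same_points = points[labels == ind]
```

Either form works for selecting one cluster's points; the boolean mask is just slightly shorter.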

Thus, inside the for loop

for i in centroids:

you need to access only the training points from that specific cluster. This will do the job:

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn import metrics
import matplotlib.pyplot as plt
import numpy as np

def KMeansModel(n):
    pca = PCA(n_components=2)
    reduced_train_data = pca.fit_transform(train_data)
    KM = KMeans(n_clusters=n)
    KM.fit(reduced_train_data)
    plt.plot(reduced_train_data[:, 0], reduced_train_data[:, 1], 'k.', markersize=2)
    centroids = KM.cluster_centers_
    # Plot the centroids as a red X
    plt.scatter(centroids[:, 0], centroids[:, 1],
                marker='x', color='r')
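    for ind, i in enumerate(centroids):
        # Row positions of the training points assigned to cluster `ind`
        class_inds = np.where(KM.labels_ == ind)[0]
        # pairwise_distances expects 2D inputs, so reshape the centroid to shape (1, 2)
        max_dist = np.max(metrics.pairwise_distances(i.reshape(1, -1), reduced_train_data[class_inds]))
        print(max_dist)
        plt.gca().add_artist(plt.Circle(i, max_dist, fill=False))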
    plt.show()

out = [KMeansModel(n) for n in np.arange(1,16,1)]

And here is one of the figures that I get using the code:

[image: one of the resulting plots]
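If the per-cluster np.where lookup becomes a bottleneck, one possible alternative is to compute all the radii in a single pass. This is only a sketch: cluster_radii is a hypothetical helper name, and the toy arrays stand in for reduced_train_data, KM.cluster_centers_ and KM.labels_.

```python
import numpy as np

def cluster_radii(points, centroids, labels):
    """Distance from each centroid to the farthest point assigned to it."""
    # Distance from every point to the centroid of its own cluster
    dists = np.linalg.norm(points - centroids[labels], axis=1)
    radii = np.zeros(len(centroids))
    # Unbuffered in-place maximum: radii[k] becomes the max of dists over label k
    np.maximum.at(radii, labels, dists)
    return radii

# Toy data (assumed for illustration only)
points = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 0.0], [10.0, 3.0]])
labels = np.array([0, 0, 1, 1])
centroids = np.array([[1.0, 0.0], [10.0, 1.5]])
radii = cluster_radii(points, centroids, labels)
```

With a fitted model, the call would presumably be cluster_radii(reduced_train_data, KM.cluster_centers_, KM.labels_), and each radii[k] can be passed to plt.Circle for centroid k. A cluster with no assigned points keeps a radius of 0 in this sketch.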

Upvotes: 2
