Alex Kinman

Reputation: 2605

How to retrieve the cluster centroids in scikit-learn's K-means?

I'm using this simple script to cluster data with scikit-learn:

from sklearn.cluster import KMeans
import pandas as pd 
import matplotlib.pyplot as plt
X = pd.read_csv('TestData.csv')
est = KMeans(n_clusters=10)
Y = pd.DataFrame(est.fit_predict(X))
frames = [X,Y]
Out = pd.concat(frames, axis = 1)  

This gives me the following output:

   (feat1) (feat2) (cluster ID) 
0   0.866  1124.182  9
1   2.078  2688.612  1
2   0.000     0.000  0
3   0.000     0.000  0
4   1.038  1344.306  6
5   2.388  3090.338  5
6   0.580   749.456  8
7   1.556  2016.456  2

I want to also display the centroids for each cluster so that the output looks like:

   (feat1) (feat2) (cluster ID) (centroid 1) (centroid 2)  
0   0.866  1124.182  9
1   2.078  2688.612  1
2   0.000     0.000  0
3   0.000     0.000  0
4   1.038  1344.306  6
5   2.388  3090.338  5
6   0.580   749.456  8
7   1.556  2016.456  2

I tried using est.cluster_centers_, but that didn't work.

How can I get the correct cluster centers?

Upvotes: 1

Views: 3414

Answers (1)

unutbu

Reputation: 879113

import sklearn.cluster as cluster
import pandas as pd 
import numpy as np 
np.random.seed(2016)

X = pd.DataFrame(np.random.random((100, 2)))
est = cluster.KMeans(n_clusters=10)
# Cluster label (0..9) for each row of X
Y = pd.DataFrame(est.fit_predict(X), columns=['cluster ID'])
# Look up each row's centroid by its cluster label
Z = pd.DataFrame(est.cluster_centers_[Y['cluster ID']],
                 columns=['centroid_x', 'centroid_y'])
result = pd.concat([X, Y, Z], axis=1)

print(result.head())

yields

          0         1  cluster ID  centroid_x  centroid_y
0  0.896705  0.730239           4    0.900182    0.772332
1  0.783276  0.741652           7    0.705625    0.720808
2  0.462090  0.642565           6    0.279384    0.689603
3  0.224864  0.708547           6    0.279384    0.689603
4  0.747126  0.625107           7    0.705625    0.720808
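The key step is est.cluster_centers_[Y['cluster ID']]. cluster_centers_ is a NumPy array with one row per cluster, shape (n_clusters, n_features). Indexing it with the array of per-point labels uses NumPy integer-array ("fancy") indexing to pick out each point's centroid row. Here is a minimal sketch on a tiny hand-made dataset. The data and the parameters n_clusters=2, n_init=10, random_state=0 are assumptions chosen for illustration only:

```python
import numpy as np
from sklearn.cluster import KMeans

# Two well-separated groups of points (made-up illustration data)
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])

est = KMeans(n_clusters=2, n_init=10, random_state=0)
labels = est.fit_predict(X)

# One row per cluster: shape (n_clusters, n_features)
print(est.cluster_centers_.shape)  # (2, 2)

# Fancy indexing: row i is the centroid of the cluster that point i belongs to
per_point = est.cluster_centers_[labels]
print(per_point.shape)  # (4, 2)
```

Because of this, cluster_centers_ on its own has only n_clusters rows. It lines up with the data row by row only after it has been indexed by the labels.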

Note that Z repeats each centroid once for every point in its cluster, so it adds a lot of redundant data to the DataFrame. You may not want to do that if the dataset is large.
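If the per-row copy is too expensive, one alternative is to keep the centroids in a separate table with one row per cluster and join them onto the points only when needed. This is a sketch and was not part of the original answer. The centroid_x/centroid_y column names follow the answer. The names centroids, labels and points are hypothetical. n_init=10 is set explicitly only to avoid version-dependent defaults:

```python
import numpy as np
import pandas as pd
import sklearn.cluster as cluster

np.random.seed(2016)
X = pd.DataFrame(np.random.random((100, 2)))
est = cluster.KMeans(n_clusters=10, n_init=10)
labels = pd.Series(est.fit_predict(X), name='cluster ID')

# Compact table: only n_clusters rows, indexed by cluster ID
centroids = pd.DataFrame(est.cluster_centers_,
                         columns=['centroid_x', 'centroid_y'])
centroids.index.name = 'cluster ID'

# Points keep just their label
points = pd.concat([X, labels], axis=1)
print(centroids.shape)  # (10, 2)

# When the wide layout is needed, join centroid columns on the label
result = points.join(centroids, on='cluster ID')
print(result.shape)  # (100, 5)
```

The 10-row centroids table is stored once. The join rebuilds the same wide layout as the answer's result only when you actually ask for it.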

Upvotes: 3
