shsh

Reputation: 727

How To Plot n Furthest Points From Each Centroid KMeans

I am trying to train a kmeans model on the iris dataset in Python.

Is there a way to plot the n points furthest from each centroid after fitting KMeans?

Here is a reproducible example of what I have so far:

from sklearn import datasets
from sklearn.cluster import KMeans
import numpy as np

# import iris dataset
iris = datasets.load_iris()
X = iris.data[:, 2:4] # use two variables (petal length and petal width)

# plot the two variables to check number of clusters
import matplotlib.pyplot as plt
plt.scatter(X[:, 0], X[:, 1]) 
 
# kmeans
km = KMeans(n_clusters = 2, random_state = 0) # Chose two clusters
y_pred = km.fit_predict(X)

X_dist = km.transform(X) # get distances to each centroid

## Stuck at this point: How to make a function that extracts three points that are furthest from the two centroids
max3IdxArr = []
for label in np.unique(km.labels_):
    X_label_indices = np.where(y_pred == label)[0]
    max3Idx = X_label_indices[np.argsort(X_dist[:3])] # This part is wrong
    max3IdxArr.append(max3Idx)
  
max3IdxArr

# plot
idx = np.concatenate(max3IdxArr) # X is a NumPy array, so index it directly (no .iloc)
plt.scatter(X[idx, 0], X[idx, 1])

Upvotes: 0

Views: 166

Answers (1)

shyam_gupta

Reputation: 328

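In your code you call np.argsort(X_dist[:3]).

The slice X_dist[:3] takes the first three rows of the unsorted X_dist before any sorting happens, so only those three rows get ordered. Instead, sort first:

x = np.argsort(X_dist)

and slice after the sorting is done:

x[:3]

Keep in mind that np.argsort sorts in ascending order, so for a single column of distances the furthest points are the last three indices ([-3:]), not the first three. X_dist also has one column per centroid, so sort only the column that belongs to the cluster you are looking at.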
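To make the sort-then-slice idea concrete, here is a minimal sketch, not necessarily the only way to do it. The helper name furthest_points_per_cluster and the explicit n_init=10 are my own assumptions and do not come from the question; the 2:4 column slice selects the same two petal columns the question uses.

```python
import numpy as np
from sklearn import datasets
from sklearn.cluster import KMeans


def furthest_points_per_cluster(X, km, n=3):
    """Hypothetical helper: map each cluster label to the row indices of
    its n members that lie furthest from that cluster's own centroid."""
    labels = km.predict(X)
    dist = km.transform(X)  # shape (n_samples, n_clusters)
    furthest = {}
    for label in np.unique(labels):
        members = np.where(labels == label)[0]  # rows assigned to this cluster
        own_dist = dist[members, label]         # distance to own centroid only
        order = np.argsort(own_dist)            # ascending, so furthest come last
        furthest[int(label)] = members[order[-n:]]
    return furthest


iris = datasets.load_iris()
X = iris.data[:, 2:4]  # petal length and petal width
km = KMeans(n_clusters=2, random_state=0, n_init=10).fit(X)  # n_init is an assumption
furthest = furthest_points_per_cluster(X, km, n=3)
```

To plot the result, pass the indices straight to the NumPy array, for example plt.scatter(X[idx, 0], X[idx, 1]) for each idx in furthest.values().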

Feel free to ask if this isn't working.

Cheers

Upvotes: 1
