Rachael

Reputation: 315

How can I get a representative point of a GMM cluster?

I have clustered my data (an array of shape (75000, 3)) using the sklearn Gaussian mixture model (GMM) algorithm, with 4 clusters. Each point of my data represents a molecular structure. Now I would like to get the most representative molecular structure of each cluster, which I understand is the centroid of the cluster. So far, I have tried to locate the point (structure) that lies exactly at the centre of each cluster using the gmm.means_ attribute, but that exact point does not correspond to any structure in my data (I checked with numpy.where). I need the coordinates of the structure closest to the centroid, but I have not found a function that does this in the documentation of the module (http://scikit-learn.org/stable/modules/generated/sklearn.mixture.GaussianMixture.html). How can I get a representative structure of each cluster?
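A minimal sketch of the lookup described above, assuming a fitted GaussianMixture named gmm and a data array X with 3 columns; the random placeholder data and variable names are illustrative, not the asker's actual code. Because each row of gmm.means_ is a weighted average of the data, it almost never coincides exactly with a row of X, so an exact-match search such as numpy.where comes up empty; searching for the nearest row instead always returns an existing structure. Note that this uses plain Euclidean distance and ignores each cluster's covariance.

```python
import numpy as np
from sklearn.metrics import pairwise_distances_argmin
from sklearn.mixture import GaussianMixture

# Placeholder data standing in for the real (75000, 3) array of structures.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(loc=c, scale=0.5, size=(500, 3))
               for c in ([0, 0, 0], [4, 0, 0], [0, 4, 0], [0, 0, 4])])

gmm = GaussianMixture(n_components=4, random_state=0).fit(X)

# For each row of gmm.means_, the index of the closest row of X
# (Euclidean distance).
nearest_idx = pairwise_distances_argmin(gmm.means_, X)
representatives = X[nearest_idx]  # one existing structure per cluster
print(nearest_idx)
print(representatives)
```

Keeping the indices (not just the coordinates) makes it easy to map each representative back to the original molecular structure.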

Thanks a lot for your help; any suggestion will be appreciated.

(As this is a generic question, I didn't think it necessary to add the code used for the clustering or any data; please let me know if it is needed.)

Upvotes: 5

Views: 9396

Answers (2)

David Dale

Reputation: 11434

For each cluster, evaluate that cluster's Gaussian density at every training point, and choose the point with the maximal density to represent the cluster:

This code can serve as an example:

import numpy as np
import matplotlib.pyplot as plt
import scipy.stats
from sklearn import mixture

# Synthetic 2-D data: three blobs, the first one stretched by the matrix C
n_samples = 100
C = np.array([[0.8, -0.1], [0.2, 0.4]])

X = np.r_[np.dot(np.random.randn(n_samples, 2), C),
          np.random.randn(n_samples, 2) + np.array([-2, 1]),
          np.random.randn(n_samples, 2) + np.array([1, -3])]

# covariance_type='full' gives one full (2, 2) matrix per component in gmm.covariances_
gmm = mixture.GaussianMixture(n_components=3, covariance_type='full').fit(X)

plt.scatter(X[:, 0], X[:, 1], s=1)

# For each component, pick the training point at which that component's
# Gaussian log-density is highest
centers = np.empty(shape=(gmm.n_components, X.shape[1]))
for i in range(gmm.n_components):
    density = scipy.stats.multivariate_normal(cov=gmm.covariances_[i], mean=gmm.means_[i]).logpdf(X)
    centers[i, :] = X[np.argmax(density)]
plt.scatter(centers[:, 0], centers[:, 1], s=20)
plt.show()

It would draw the centers as orange dots:

[Image: the data points with the selected centers drawn as orange dots]
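The loop in this answer indexes gmm.covariances_[i], which only matches the per-component covariance for some covariance types; with covariance_type='tied', for example, covariances_ is a single shared (n_features, n_features) matrix, so covariances_[i] is just one of its rows. A minimal sketch that expands every covariance type to full matrices first, assuming that expansion is acceptable for your use; the helper names full_covariances and representative_indices are hypothetical. It also returns row indices rather than coordinates, which makes it easy to look up the original structure.

```python
import numpy as np
import scipy.stats
from sklearn.mixture import GaussianMixture


def full_covariances(gmm):
    """Expand gmm.covariances_ to shape (n_components, n_features, n_features)."""
    k, d = gmm.means_.shape
    cov = gmm.covariances_
    if gmm.covariance_type == "full":       # already (k, d, d)
        return cov
    if gmm.covariance_type == "tied":       # one (d, d) matrix shared by all
        return np.repeat(cov[np.newaxis, :, :], k, axis=0)
    if gmm.covariance_type == "diag":       # (k, d): diagonal entries only
        return np.array([np.diag(c) for c in cov])
    if gmm.covariance_type == "spherical":  # (k,): one variance per component
        return np.array([v * np.eye(d) for v in cov])
    raise ValueError(f"unknown covariance_type {gmm.covariance_type!r}")


def representative_indices(gmm, X):
    """Row index of X with the highest density under each component."""
    covs = full_covariances(gmm)
    idx = np.empty(gmm.n_components, dtype=int)
    for i in range(gmm.n_components):
        logpdf = scipy.stats.multivariate_normal(
            mean=gmm.means_[i], cov=covs[i]).logpdf(X)
        idx[i] = np.argmax(logpdf)
    return idx


# Illustrative synthetic data: three 2-D blobs.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(size=(200, 2)) + c
               for c in ([0, 0], [-4, 2], [3, -4])])

results = {}
for ct in ("full", "tied", "diag", "spherical"):
    gmm = GaussianMixture(n_components=3, covariance_type=ct,
                          random_state=0).fit(X)
    results[ct] = representative_indices(gmm, X)
    print(ct, results[ct], X[results[ct]])
```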

Upvotes: 14

Has QUIT--Anony-Mousse

Reputation: 77485

Find the point with the smallest Mahalanobis distance to the cluster center.

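This works because each GMM component scores points by their Mahalanobis distance to its mean: for a fixed component, the Gaussian log-density is a constant minus half the squared Mahalanobis distance. Under the GMM model, this point is therefore the one with the highest density for that cluster.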

Everything needed to compute this is already on the fitted model: the cluster means (means_) and covariances (covariances_).
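A minimal sketch of this suggestion, assuming covariance_type='full' so that gmm.covariances_[i] is a full matrix; the function name closest_by_mahalanobis and the synthetic data are illustrative assumptions. Since the component log-density is a constant minus half the squared Mahalanobis distance, the point found here should match the density maximizer from the other answer, barring exact ties.

```python
import numpy as np
from sklearn.mixture import GaussianMixture


def closest_by_mahalanobis(gmm, X):
    """Index of the row of X with the smallest Mahalanobis distance to each mean.

    Assumes gmm.covariance_type == 'full'.
    """
    idx = np.empty(gmm.n_components, dtype=int)
    for i in range(gmm.n_components):
        diff = X - gmm.means_[i]                   # (n_samples, n_features)
        prec = np.linalg.inv(gmm.covariances_[i])  # inverse covariance matrix
        # Squared Mahalanobis distance of every row: diff[r] @ prec @ diff[r]
        d2 = np.einsum("ij,jk,ik->i", diff, prec, diff)
        idx[i] = np.argmin(d2)
    return idx


# Illustrative synthetic data: four 3-D blobs.
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(size=(300, 3)) + c
               for c in ([0, 0, 0], [5, 0, 0], [0, 5, 0], [0, 0, 5])])
gmm = GaussianMixture(n_components=4, covariance_type="full",
                      random_state=0).fit(X)

idx = closest_by_mahalanobis(gmm, X)
print(idx)
print(X[idx])  # one existing data point per cluster
```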

Upvotes: 0
