avpenn

Reputation: 55

Semi-supervised Gaussian mixture model clustering in Python

I have images that I am segmenting using a Gaussian mixture model from scikit-learn. Some of the images are labeled, so I have a good deal of prior information that I would like to use. I would like to train the mixture model in a semi-supervised way by providing some of the cluster assignments ahead of time.

From the MATLAB documentation, I can see that MATLAB allows initial values to be set. Are there any Python libraries, especially scikit-learn approaches, that would allow this?

Upvotes: 3

Views: 2399

Answers (1)

lightalchemist

Reputation: 10211

The standard GMM does not work in a semi-supervised fashion. The initial values you mentioned are likely the initial values for the mean vectors and covariance matrices of the Gaussians, which are then updated by the EM algorithm.

A simple hack would be to group your labeled data by label, estimate a mean vector and covariance matrix for each group separately, and pass these as the initial values to your MATLAB function (in scikit-learn, `GaussianMixture` accepts such initial values through its `weights_init`, `means_init` and `precisions_init` parameters). Hopefully this will position your Gaussians at the "correct locations"; the EM algorithm then takes over and adjusts these parameters.
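A minimal sketch of this hack in scikit-learn, assuming a version whose `GaussianMixture` has the `weights_init`, `means_init` and `precisions_init` parameters (the old `GMM` class did not). The helper `init_from_labels` is hypothetical, written only for this example, and the two synthetic 2-D blobs stand in for your image features. A small ridge `reg` is added to each covariance so that a cluster with few labeled points does not produce a singular matrix.

```python
import numpy as np
from sklearn.mixture import GaussianMixture


def init_from_labels(X_labeled, y_labeled, n_components, reg=1e-6):
    """Hypothetical helper: estimate GMM weights, means and precisions
    from labeled points only.

    Assumes labels are integers 0..n_components-1 and that every label
    occurs at least once. `reg` is added to each covariance diagonal so a
    cluster with very few points does not yield a singular matrix.
    """
    n_features = X_labeled.shape[1]
    weights = np.empty(n_components)
    means = np.empty((n_components, n_features))
    precisions = np.empty((n_components, n_features, n_features))
    for k in range(n_components):
        Xk = X_labeled[y_labeled == k]
        weights[k] = len(Xk)
        means[k] = Xk.mean(axis=0)
        # bias=True: maximum-likelihood estimate, defined even for one sample
        cov = np.atleast_2d(np.cov(Xk, rowvar=False, bias=True))
        cov += reg * np.eye(n_features)
        precisions[k] = np.linalg.inv(cov)
    return weights / weights.sum(), means, precisions


# Toy data (assumption): two 2-D blobs, first 10 points of each are labeled.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, (200, 2)), rng.normal(8.0, 1.0, (200, 2))])
labeled_idx = np.r_[0:10, 200:210]
y_labeled = np.r_[np.zeros(10, dtype=int), np.ones(10, dtype=int)]

w0, mu0, prec0 = init_from_labels(X[labeled_idx], y_labeled, n_components=2)
gmm = GaussianMixture(n_components=2, covariance_type="full",
                      weights_init=w0, means_init=mu0, precisions_init=prec0)
gmm.fit(X)
pred = gmm.predict(X)
```

Because the initial parameters come from the labeled groups, component `k` starts out aligned with label `k`; EM is still free to move the components afterwards, which is exactly the limitation described next.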

The downside of this hack is that it does not guarantee that your true label assignments will be respected: even if a data point is labeled as belonging to a particular cluster, EM may still reassign it to another one. Noise in your feature vectors or labels could also cause your initial Gaussians to cover a much larger region than they are supposed to, wreaking havoc on the EM algorithm. Finally, if you do not have enough labeled points for a particular cluster, its estimated covariance matrix might be singular, which breaks this trick altogether.
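If the labels must be respected, one option (a standard semi-supervised EM variant, swapped in here rather than taken from the hack above) is to run EM yourself and clamp the responsibilities of labeled points to their known cluster in every iteration. The NumPy-only sketch below assumes full covariances, a fixed iteration count with no convergence check, and at least one labeled point per cluster; `log_gaussian` and `semi_supervised_em` are hypothetical names written for this example, and the blobs are synthetic.

```python
import numpy as np


def log_gaussian(X, mean, cov):
    """Log-density of N(mean, cov) evaluated at every row of X."""
    d = X.shape[1]
    chol = np.linalg.cholesky(cov)
    z = np.linalg.solve(chol, (X - mean).T)          # whitened residuals, (d, n)
    log_det = 2.0 * np.log(np.diag(chol)).sum()
    return -0.5 * (d * np.log(2.0 * np.pi) + log_det + (z ** 2).sum(axis=0))


def semi_supervised_em(X, y, n_components, n_iter=100, reg=1e-6):
    """Hypothetical helper: full-covariance GMM fitted by EM with labeled
    points clamped.

    y[i] is a cluster index in 0..n_components-1 for labeled points and -1
    for unlabeled ones; every cluster needs at least one labeled point.
    Labeled points keep a one-hot responsibility throughout, so EM can
    never move them to another cluster.
    """
    n, d = X.shape
    idx = np.flatnonzero(y >= 0)                     # indices of labeled points
    # Start from the labels alone: the first M step then reproduces the
    # "estimate from labeled data" initialization.
    resp = np.zeros((n, n_components))
    resp[idx, y[idx]] = 1.0
    for _ in range(n_iter):
        # M step: weighted ML estimates of weights, means and covariances.
        nk = resp.sum(axis=0)
        weights = nk / nk.sum()
        means = (resp.T @ X) / nk[:, None]
        covs = np.empty((n_components, d, d))
        for k in range(n_components):
            diff = X - means[k]
            covs[k] = (resp[:, k, None] * diff).T @ diff / nk[k]
            covs[k] += reg * np.eye(d)               # guard against singularity
        # E step: posterior responsibilities, computed in log space.
        log_p = np.column_stack([
            np.log(weights[k]) + log_gaussian(X, means[k], covs[k])
            for k in range(n_components)
        ])
        log_p -= log_p.max(axis=1, keepdims=True)
        resp = np.exp(log_p)
        resp /= resp.sum(axis=1, keepdims=True)
        # Clamp: labeled points are reset to their known cluster.
        resp[idx] = 0.0
        resp[idx, y[idx]] = 1.0
    return weights, means, covs, resp


# Toy data (assumption): two 2-D blobs, 10 labeled points in each.
rng = np.random.default_rng(1)
X = np.vstack([rng.normal(0.0, 1.0, (200, 2)), rng.normal(8.0, 1.0, (200, 2))])
y = np.full(400, -1)
y[:10] = 0
y[200:210] = 1
weights, means, covs, resp = semi_supervised_em(X, y, n_components=2)
labels = resp.argmax(axis=1)
```

Unlike the initialization hack, the labeled points here can never be reassigned; the noise and small-sample caveats still apply, which is why the sketch keeps the `reg` ridge.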

Unless you must use a GMM to cluster your data (e.g., because you know for sure that Gaussians model your data well), you could instead try the semi-supervised methods in scikit-learn. These propagate the labels to your other data points based on feature similarity. However, I doubt they can handle large datasets, because they build a graph matrix from all pairs of samples. With the default RBF kernel that matrix is dense and n-by-n; the `kernel='knn'` option of `LabelPropagation` and `LabelSpreading` builds a sparse k-nearest-neighbor graph instead, which scales considerably better.
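A minimal sketch of the label-propagation route, assuming scikit-learn's `LabelSpreading` with `kernel="knn"` so that the affinity graph stays sparse. The blobs are synthetic stand-ins for image features, and `n_neighbors=10` and `alpha=0.2` are illustrative choices, not tuned values. scikit-learn marks unlabeled samples with `-1`.

```python
import numpy as np
from sklearn.semi_supervised import LabelSpreading

# Toy data (assumption): two 2-D blobs with known ground truth.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, (200, 2)), rng.normal(8.0, 1.0, (200, 2))])
y_true = np.r_[np.zeros(200, dtype=int), np.ones(200, dtype=int)]

# Only 10 points per blob are labeled; -1 means "unlabeled".
y = np.full(400, -1)
y[:10] = 0
y[200:210] = 1

# kernel="knn" stores only n_neighbors entries per row instead of the
# dense n-by-n matrix produced by the default RBF kernel.
model = LabelSpreading(kernel="knn", n_neighbors=10, alpha=0.2)
model.fit(X, y)
inferred = model.transduction_          # a label for every sample
accuracy = (inferred == y_true).mean()
```

`alpha` controls how much the labeled points may themselves be relabeled (soft clamping); `LabelPropagation` instead keeps the given labels fixed.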

Upvotes: 2
