Smeet Patel

Reputation: 137

What is the use of predict() method in kmeans implementation of scikit learn?

Can someone explain what the predict() method is used for in scikit-learn's KMeans implementation? The official documentation describes it as:

Predict the closest cluster each sample in X belongs to.

But I can also get the cluster number/label for each sample of the input set X by fitting the model with fit_transform(). So what is the use of predict() method? Is it supposed to point out closest cluster for the unseen data? If yes, then how do you handle a new data point if you perform dimensionality reduction measure such as SVD?

Here's a similar question but I still don't think it really helps.

Upvotes: 9

Views: 31011

Answers (1)

MB-F

Reputation: 23637

what is the use of predict() method? Is it supposed to point out closest cluster for the unseen data?

Yes, exactly.
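To make "closest cluster" concrete, here is a minimal sketch of the idea in plain NumPy: each new point gets the index of the nearest cluster center by Euclidean distance. The function name `assign_to_nearest_center` and the example centers are hypothetical; in scikit-learn the centers would come from a fitted model's `cluster_centers_` attribute. This illustrates the concept and is not the library's actual implementation.

```python
import numpy as np


def assign_to_nearest_center(points, centers):
    """Return, for each point, the index of the closest center (Euclidean)."""
    points = np.asarray(points, dtype=float)
    centers = np.asarray(centers, dtype=float)
    # Pairwise squared distances, shape (n_points, n_centers).
    diffs = points[:, np.newaxis, :] - centers[np.newaxis, :, :]
    sq_dists = (diffs ** 2).sum(axis=2)
    return sq_dists.argmin(axis=1)


# Hypothetical centers, standing in for a fitted model's `cluster_centers_`.
centers = [[0.0, 0.0], [10.0, 10.0]]
new_points = [[1.0, -1.0], [9.0, 11.0], [4.0, 4.0]]

labels = assign_to_nearest_center(new_points, centers)
print(labels)
```

No fitting happens here: the centers are fixed and the new points are only compared against them.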

then how do you handle a new data point if you perform dimensionality reduction measure such as SVD?

You apply the same, already fitted dimensionality reduction model to the unseen data (with .transform(), not .fit_transform()) before passing the result to .predict(). Here is a typical workflow:

# prerequisites:
#    x_train: training data
#    x_test: "unseen" testing data
#    km: initialized `KMeans()` instance
#    dr: initialized dimensionality reduction instance (such as `TruncatedSVD()`)    

# fitting
x_dr = dr.fit_transform(x_train)
y = km.fit_predict(x_dr)  

# ...

# working with unseen data (models have been fitted before)
x_dr = dr.transform(x_test)
y = km.predict(x_dr)

# ...
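For reference, here is one way the workflow above might look as a self-contained script. The synthetic data, the array sizes, and the parameter values (`n_components=2`, `n_clusters=3`, `n_init=10`, `random_state=0`) are assumptions chosen for illustration only; substitute your own data and settings.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD

rng = np.random.default_rng(0)

# Synthetic stand-ins for real data: 20 features, three shifted groups.
offsets = np.repeat([0.0, 5.0, 10.0], 50)[:, np.newaxis]
x_all = rng.normal(size=(150, 20)) + offsets
rng.shuffle(x_all)  # shuffles the rows in place
x_train, x_test = x_all[:120], x_all[120:]

dr = TruncatedSVD(n_components=2, random_state=0)
km = KMeans(n_clusters=3, n_init=10, random_state=0)

# Fitting: both models learn from the training data only.
x_train_dr = dr.fit_transform(x_train)
y_train = km.fit_predict(x_train_dr)

# Unseen data: reuse the fitted models, no further fitting.
x_test_dr = dr.transform(x_test)
y_test = km.predict(x_test_dr)

print(x_test_dr.shape, y_test.shape)
```

The test data never influences the SVD components or the cluster centers; it is only projected and then assigned.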

Actually, methods such as fit_transform and fit_predict are there for convenience. y = km.fit_predict(x) is equivalent to y = km.fit(x).predict(x).
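A quick way to check this equivalence yourself is sketched below. It assumes well-separated synthetic data and two models created with identical settings and the same `random_state`, so that both start from the same initialization; the data and parameter values are made up for the example.

```python
import numpy as np
from sklearn.cluster import KMeans

rng = np.random.default_rng(0)
# Three well-separated synthetic groups in 2-D.
group_centers = np.array([[0.0, 0.0], [20.0, 0.0], [0.0, 20.0]])
x = np.vstack([c + rng.normal(size=(40, 2)) for c in group_centers])

# Same settings and seed, so both models start from the same state.
km_a = KMeans(n_clusters=3, n_init=10, random_state=0)
km_b = KMeans(n_clusters=3, n_init=10, random_state=0)

y_a = km_a.fit_predict(x)       # fit and label in one call
y_b = km_b.fit(x).predict(x)    # fit, then label the same data

same_labels = np.array_equal(y_a, y_b)
print(same_labels)
```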

I think it's easier to see what's going on if we write the fitting part as follows:

# fitting
dr.fit(x_train)
x_dr = dr.transform(x_train)

km.fit(x_dr)
y = km.predict(x_dr)

Except for the call to .fit(), the models are used in exactly the same way during fitting and with unseen data.

Summary:

  • The purpose of .fit() is to train the model with data.
  • The purpose of .predict() or .transform() is to apply a trained model to data.
  • If you want to fit the model and apply it to the same data during training, there are .fit_predict() or .fit_transform() for convenience.
  • When chaining multiple models (such as dimensionality reduction and clustering), apply them in the same order during fitting and testing.
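One option for keeping the order consistent, not mentioned above, is scikit-learn's `Pipeline`, which chains the steps and applies them in the same sequence for `.fit()` and `.predict()`. The sketch below uses random synthetic data and arbitrary parameter values purely for illustration.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.decomposition import TruncatedSVD
from sklearn.pipeline import make_pipeline

rng = np.random.default_rng(1)
x_train = rng.normal(size=(100, 10))  # synthetic training data
x_test = rng.normal(size=(25, 10))    # synthetic "unseen" data

pipe = make_pipeline(
    TruncatedSVD(n_components=3, random_state=0),
    KMeans(n_clusters=4, n_init=10, random_state=0),
)

# Fits the SVD, transforms x_train, then fits KMeans on the result.
pipe.fit(x_train)

# Transforms x_test with the fitted SVD, then calls KMeans.predict.
y_test = pipe.predict(x_test)

# The same thing done by hand, step by step.
manual = pipe[-1].predict(pipe[0].transform(x_test))
print(np.array_equal(y_test, manual))
```

With a pipeline there is a single object to fit and to call on new data, which makes it harder to forget the transform step or to apply the steps in the wrong order.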

Upvotes: 12
