8-Bit Borges

Reputation: 10033

Sklearn - fit, scale and transform

The fit() method in sklearn appears to serve different purposes within the same interface.

When applied to the training set, like so:

model.fit(X_train, y_train)

fit() is used to learn parameters that will later be used on the test set with predict(X_test).


However, there are cases where there seems to be no 'learning' involved with fit(), only some normalization to transform the data, like so:

from sklearn import preprocessing

min_max_scaler = preprocessing.MinMaxScaler()
min_max_scaler.fit(X_train)

which will simply scale feature values to a fixed range, say 0 to 1, so that features with higher variance do not have a disproportionate influence on the model.


To make things even less intuitive, sometimes the fit() method that scales (and already appears to be transforming) needs to be followed by a further transform() call, before fit() is called again on the model that actually learns, like so:

X_train2 = min_max_scaler.transform(X_train)
X_test2 = min_max_scaler.transform(X_test)

from sklearn.neighbors import KNeighborsClassifier

# the model being used
knn = KNeighborsClassifier(n_neighbors=3, metric="euclidean")
# learn parameters
knn.fit(X_train2, y_train)
# predict
y_pred = knn.predict(X_test2)

Could someone please clarify the use, or multiple uses, of fit(), as well as the difference of scaling and transforming the data?

Upvotes: 4

Views: 2174

Answers (2)

krisograbek

Reputation: 1782

In scikit-learn there are three kinds of objects that share an interface: Estimators, Transformers and Predictors.

Estimators have a fit() method, which always serves the same purpose: it estimates parameters based on the dataset.

Transformers have a transform() method, which returns the transformed dataset. Some Estimators are also Transformers, e.g. MinMaxScaler().

Predictors have a predict() method, which returns predictions on new instances, e.g. KNeighborsClassifier().

Both MinMaxScaler() and KNeighborsClassifier() have a fit() method because they share the Estimator interface.
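As a quick sketch of how these interfaces differ (the hasattr checks here are just for illustration):

```python
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier

scaler = MinMaxScaler()
knn = KNeighborsClassifier(n_neighbors=3)

# Both share the Estimator interface:
print(hasattr(scaler, "fit"), hasattr(knn, "fit"))              # True True
# But only the Transformer has transform():
print(hasattr(scaler, "transform"), hasattr(knn, "transform"))  # True False
# And only the Predictor has predict():
print(hasattr(scaler, "predict"), hasattr(knn, "predict"))      # False True
```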

However, there are cases when there is no 'learning' involved with fit()

There is 'learning' involved. The Transformer MinMaxScaler() has to 'learn' the min and max values for each numerical feature. When you call min_max_scaler.fit(X_train), your scaler estimates those values for each numerical column in your train set. min_max_scaler.transform(X_train) then scales your train set based on those estimates, and min_max_scaler.transform(X_test) scales the test set with the estimates learned from the train set. It is important to scale both the train and test sets with the same estimates.
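This can be checked directly with a tiny example (the toy values below are made up for illustration):

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

X_train = np.array([[1.0], [3.0], [5.0]])
X_test = np.array([[4.0], [6.0]])

scaler = MinMaxScaler()
scaler.fit(X_train)  # 'learns' data_min_ = 1.0 and data_max_ = 5.0

# transform() applies (x - min) / (max - min) using the *train* statistics
print(scaler.transform(X_train).ravel())  # values 0.0, 0.5, 1.0
print(scaler.transform(X_test).ravel())   # values 0.75, 1.25
```

Note that test values scaled with the train statistics can fall outside [0, 1], which is expected: the point is consistency with the train set, not clamping.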

For further reading, you can check this: https://arxiv.org/abs/1309.0238

Upvotes: 2

Antoine Dubuis

Reputation: 5304

The fit() function provides a common interface that is shared among all scikit-learn objects.

This function takes as arguments X (and sometimes a y array) to compute the object's statistics. For example, calling fit on a MinMaxScaler transformer will compute its statistics (data_min_, data_max_, data_range_, ...).

Therefore we should see the fit() function as a method that computes the necessary statistics of an object.

This common interface is really helpful as it allows you to combine transformers and estimators using a Pipeline. This lets you fit and predict through all steps in one go, as follows:

from sklearn.pipeline import make_pipeline
from sklearn.datasets import make_classification

from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors


X, y = make_classification(n_samples=1000)
model = make_pipeline(MinMaxScaler(), NearestNeighbors())
model.fit(X, y)

This also offers the possibility to serialize the whole model into a single object.
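For example, a fitted pipeline can be persisted with joblib, which the scikit-learn docs recommend for model persistence (the filename below is arbitrary):

```python
import joblib
from sklearn.pipeline import make_pipeline
from sklearn.datasets import make_classification
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import KNeighborsClassifier

X, y = make_classification(n_samples=100, random_state=0)
model = make_pipeline(MinMaxScaler(), KNeighborsClassifier(n_neighbors=3))
model.fit(X, y)

# The fitted scaler and classifier travel together in one file
joblib.dump(model, "model.joblib")
restored = joblib.load("model.joblib")
print((restored.predict(X) == model.predict(X)).all())  # True
```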

Without this composition module, I agree with you that it is not very practical to work with independent transformers and estimators.

Upvotes: 3
