giacrava

Reputation: 169

ColumnTransformer inside a Pipeline

I'm building a pipeline in scikit-learn. I need to apply different transformations to different sets of features and then standardize all of them, so I built a ColumnTransformer with a custom transformer for each set of columns:

transformation_pipeline = ColumnTransformer([
    ('adoption', TransformAdoptionFeatures, features_adoption),
    ('census', TransformCensusFeaturesRegr, features_census),
    ('climate', TransformClimateFeatures, features_climate),
    ('soil', TransformSoilFeatures, features_soil),
    ('economic', TransformEconomicFeatures, features_economic)
],
    remainder='drop')

Then, since I'd like to create two different pipelines, one that standardizes and one that normalizes my features, I thought of combining transformation_pipeline and the scaler in a single Pipeline:

full_pipeline_stand = Pipeline([
    ('transformation', transformation_pipeline()),
    ('scaling', StandardScaler())
])

However, I get the following error:

TypeError: 'ColumnTransformer' object is not callable

Is there a way to do this without building a separate pipeline for each set of columns (each one combining the custom transformer and the scaler)? That approach obviously works, but it looks like useless repetition to me. Thanks!

Upvotes: 2

Views: 2465

Answers (1)

giacrava

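I spotted my error: I had mixed up where the classes get instantiated. The custom transformers have to be instantiated inside the ColumnTransformer (TransformAdoptionFeatures() rather than TransformAdoptionFeatures), while transformation_pipeline is already a ColumnTransformer instance, so it must not be called again inside the Pipeline. Writing transformation_pipeline() tries to call that instance, which is what raises TypeError: 'ColumnTransformer' object is not callable.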
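A plain-Python sketch of why the original call failed: calling a class creates an instance, but calling an instance only works if its class defines __call__, which ColumnTransformer does not. The class name Transformer below is hypothetical and only mimics the fit/transform shape of a scikit-learn transformer.

```python
class Transformer:
    """Hypothetical minimal transformer-like class (no __call__ defined)."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return X


instance = Transformer()       # calling the class creates an instance
print(callable(Transformer))   # True
print(callable(instance))      # False: the class defines no __call__

message = ""
try:
    instance()                 # the equivalent of transformation_pipeline()
except TypeError as err:
    message = str(err)
print(message)                 # 'Transformer' object is not callable
```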

The correct code is the following:

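from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# The custom transformers (defined elsewhere) are instantiated,
# note the parentheses, inside the ColumnTransformer.
transformation_pipeline = ColumnTransformer([
    ('adoption', TransformAdoptionFeatures(), features_adoption),
    ('census', TransformCensusFeaturesRegr(), features_census),
    ('climate', TransformClimateFeatures(), features_climate),
    ('soil', TransformSoilFeatures(), features_soil),
    ('economic', TransformEconomicFeatures(), features_economic)
], remainder='drop')

# transformation_pipeline is already an object, so it is passed as-is.
full_pipeline_stand = Pipeline([
    ('transformation', transformation_pipeline),
    ('scaling', StandardScaler())
])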
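For anyone who wants to run the whole pattern end to end, here is a sketch with two hypothetical toy transformers (LogTransform and SquareTransform, standing in for the custom feature transformers, which the question does not show) on a small NumPy array with integer column indices. It also builds the second, normalizing pipeline; using MinMaxScaler for "normalize" is an assumption (scikit-learn's Normalizer instead rescales each row to unit norm). It requires scikit-learn and NumPy.

```python
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import MinMaxScaler, StandardScaler


class LogTransform(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in for a custom transformer: log1p of each column."""

    def fit(self, X, y=None):
        return self  # stateless: nothing to learn

    def transform(self, X):
        return np.log1p(np.asarray(X, dtype=float))


class SquareTransform(BaseEstimator, TransformerMixin):
    """Hypothetical stand-in for a second custom transformer: squares values."""

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return np.asarray(X, dtype=float) ** 2


# Instances (with parentheses) go inside the ColumnTransformer.
transformation_pipeline = ColumnTransformer(
    [
        ("log", LogTransform(), [0, 1]),
        ("square", SquareTransform(), [2]),
    ],
    remainder="drop",  # column 3 is not listed, so it is dropped
)

# The ColumnTransformer object is passed without parentheses; clone()
# gives each pipeline its own unfitted copy.
full_pipeline_stand = Pipeline([
    ("transformation", clone(transformation_pipeline)),
    ("scaling", StandardScaler()),
])
full_pipeline_norm = Pipeline([
    ("transformation", clone(transformation_pipeline)),
    ("scaling", MinMaxScaler()),
])

X = np.array([[1.0, 2.0, 3.0, 99.0],
              [3.0, 5.0, 1.0, 99.0],
              [7.0, 0.0, 2.0, 99.0]])

X_stand = full_pipeline_stand.fit_transform(X)
X_norm = full_pipeline_norm.fit_transform(X)
print(X_stand.shape)  # (3, 3)
```

clone() is a design choice here rather than a requirement: by default a Pipeline fits the step objects it was given in place, so if both pipelines shared one ColumnTransformer object, fitting the second would overwrite the fitted state of the first.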

Upvotes: 2
