dzieciou

Reputation: 4514

Why does ColumnTransformer not call fit on its transformers?

I have defined data for fitting with one categorical feature "sex":

import pandas as pd

data = pd.DataFrame({
    'age': [25, 19, 17],
    'sex': ['female', 'male', 'female'],
    'won_lottery': [False, True, False]
})
X = data[['age', 'sex']]
y = data['won_lottery']

and pipeline for transforming categorical features:

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

ohe = OneHotEncoder(handle_unknown='ignore')
cat_transformers = Pipeline([
    ('onehot', ohe)
])

When I fit cat_transformers on the data directly:

cat_transformers.fit(X[['sex']], y)
print(ohe.get_feature_names())

I am able to get names of output features created by OneHotEncoder instance:

['x0_female' 'x0_male']    

However, if I encapsulate cat_transformers into ColumnTransformer:

from sklearn.compose import ColumnTransformer

preprocessor = ColumnTransformer(
    transformers=[
        ('cat', cat_transformers, ['sex'])
    ]
)
preprocessor.fit(X, y)
print(ohe.get_feature_names())

it fails with

sklearn.exceptions.NotFittedError: This OneHotEncoder instance is not fitted yet. 
  Call 'fit' with appropriate arguments before using this method.

I would expect calling fit() on a ColumnTransformer to call fit() on all of its transformers.

Why does it not work this way?

Upvotes: 10

Views: 7149

Answers (3)

T L

Reputation: 11

Adding to what Gonzalo Garcia said: I think what's happening is that the transformers parameter of the __init__ method is stored as-is in the transformers attribute of the instance. The ColumnTransformer documentation names transformers_ (with a trailing underscore) as the attribute we are supposed to use to access the fitted transformer objects.

It might have made sense to tuck the parameters away in hidden attributes to avoid the confusion, but that's for a future version I suppose!
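To make the difference between the two attributes concrete, here is a minimal sketch. It passes the encoder to ColumnTransformer directly rather than through a Pipeline, which is a simplification of the question's setup, and it checks for the categories_ attribute as a stand-in for "is fitted":

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({"age": [25, 19, 17], "sex": ["female", "male", "female"]})

ohe = OneHotEncoder(handle_unknown="ignore")
preprocessor = ColumnTransformer(transformers=[("cat", ohe, ["sex"])])
preprocessor.fit(X)

# `transformers` holds exactly what was passed to __init__ (name, object, columns).
original = preprocessor.transformers[0][1]
# `transformers_` holds the fitted copies created during fit().
fitted = preprocessor.transformers_[0][1]

print(original is ohe)                   # the very same object that was passed in
print(hasattr(original, "categories_"))  # fitted attributes are absent here
print(hasattr(fitted, "categories_"))    # and present on the fitted copy
print(fitted.categories_)
```

So anything that needs learned state (categories, feature names) should be read from transformers_ or named_transformers_, not from the object passed in.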

Upvotes: 0

Gonzalo Garcia

Reputation: 6612

Not an answer but I came to see how OHE names could be obtained from a ColumnTransformer

(Thanks @skibee). You have to use transformers_ instead.

# assuming that your OHE is in that index position 
preprocessor.transformers_[2][1][1].get_feature_names()

Upvotes: 0

dzieciou

Reputation: 4514

Ok, I understand it now. I was fitting one instance of OneHotEncoder and checking features on another instance:

print(id(ohe))
print(id(preprocessor.named_transformers_['cat'].named_steps['onehot']))

2757198591872
2755226729104

It looks like ColumnTransformer clones its transformers before fitting.
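As far as I can tell, the copying is done with sklearn.base.clone, which builds a new estimator with the same parameters but none of the fitted state. A small sketch of what clone does on its own (the toy input list is made up for illustration, and I have not traced exactly where ColumnTransformer calls it):

```python
from sklearn.base import clone
from sklearn.preprocessing import OneHotEncoder

ohe = OneHotEncoder(handle_unknown="ignore")
ohe.fit([["female"], ["male"], ["female"]])  # toy data, one categorical column

copy = clone(ohe)

# A clone is a different object with identical constructor parameters...
print(copy is ohe)
print(copy.get_params() == ohe.get_params())

# ...but it does not carry over anything learned during fit().
print(hasattr(ohe, "categories_"))
print(hasattr(copy, "categories_"))
```

This is the same situation as in the question, just in reverse: there the clone inside the ColumnTransformer was fitted while the original ohe stayed unfitted, which is why calling a feature-name method on ohe raised NotFittedError.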

Upvotes: 9
