Reputation: 4514
I have defined data for fitting with one categorical feature "sex":
import pandas as pd

data = pd.DataFrame({
    'age': [25, 19, 17],
    'sex': ['female', 'male', 'female'],
    'won_lottery': [False, True, False]
})
X = data[['age', 'sex']]
y = data['won_lottery']
and pipeline for transforming categorical features:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

ohe = OneHotEncoder(handle_unknown='ignore')
cat_transformers = Pipeline([
    ('onehot', ohe)
])
When fitting cat_transformers with data directly:
cat_transformers.fit(X[['sex']], y)
print(ohe.get_feature_names())
I am able to get the names of the output features created by the OneHotEncoder instance:
['x0_female' 'x0_male']
However, if I encapsulate cat_transformers into a ColumnTransformer:
from sklearn.compose import ColumnTransformer

preprocessor = ColumnTransformer(
    transformers=[
        ('cat', cat_transformers, ['sex'])
    ]
)
preprocessor.fit(X, y)
print(ohe.get_feature_names())
it fails with
sklearn.exceptions.NotFittedError: This OneHotEncoder instance is not fitted yet.
Call 'fit' with appropriate arguments before using this method.
I would expect that calling fit() on the ColumnTransformer causes fit() to be called on all of its transformers.
Why does it not work this way?
Upvotes: 10
Views: 7149
Reputation: 11
Adding to what Gonzalo Garcia said: I think what's happening is that the transformers parameter from the __init__ method is stored unchanged in the transformers attribute of the class. The ColumnTransformer documentation calls out transformers_ (with a trailing underscore) as the attribute we are supposed to use to access the fitted transformer objects.
It might have made sense to tuck the parameters away in hidden attributes to avoid the confusion, but that's for a future version I suppose!
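To make the difference between the two attributes concrete, here is a minimal sketch that reuses the names from the question (preprocessor, cat_transformers, ohe). The variable names same_as_input, fitted_pipeline and is_copy are only illustrative, and the behaviour shown is what I would expect from current scikit-learn releases rather than a documented guarantee.

```python
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({'age': [25, 19, 17],
                  'sex': ['female', 'male', 'female']})

ohe = OneHotEncoder(handle_unknown='ignore')
cat_transformers = Pipeline([('onehot', ohe)])
preprocessor = ColumnTransformer(
    transformers=[('cat', cat_transformers, ['sex'])]
)
preprocessor.fit(X)

# `transformers` (no underscore) holds the objects exactly as they
# were passed to __init__.
same_as_input = preprocessor.transformers[0][1] is cat_transformers

# `transformers_` (trailing underscore) holds the fitted transformers,
# which are different objects from the ones passed in.
fitted_pipeline = preprocessor.transformers_[0][1]
is_copy = fitted_pipeline is not cat_transformers

print(same_as_input, is_copy)
```

Each entry of transformers_ is a (name, fitted_transformer, columns) tuple, so index [1] picks out the fitted pipeline.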
Upvotes: 0
Reputation: 6612
Not an answer, but I came here to see how OHE names could be obtained from a ColumnTransformer (thanks @skibee). You have to use transformers_ instead.
# assuming that your OHE is in that index position
preprocessor.transformers_[2][1][1].get_feature_names()
Upvotes: 0
Reputation: 4514
OK, I understand it now. I was fitting one instance of OneHotEncoder and checking features on another instance:
print(id(ohe))
print(id(preprocessor.named_transformers_['cat'].named_steps['onehot']))
2757198591872
2755226729104
It looks like ColumnTransformer clones its transformers before fitting.
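For anyone landing here later, this is a small sketch of how the feature names can be read from the fitted clone instead of the original ohe. It rebuilds the objects from the question; fitted_ohe and names are just illustrative variable names. One assumption to flag: get_feature_names() was replaced by get_feature_names_out() in newer scikit-learn releases, so the sketch falls back to whichever method the installed version provides.

```python
import pandas as pd
from sklearn.base import clone
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder

X = pd.DataFrame({'age': [25, 19, 17],
                  'sex': ['female', 'male', 'female']})

ohe = OneHotEncoder(handle_unknown='ignore')
preprocessor = ColumnTransformer(
    transformers=[('cat', Pipeline([('onehot', ohe)]), ['sex'])]
)
preprocessor.fit(X)

# clone() builds a new, unfitted estimator with the same parameters.
# A copy made this way is what gets fitted, so the original `ohe`
# never learns anything.
assert clone(ohe) is not ohe
original_is_fitted = hasattr(ohe, 'categories_')

# Reach the fitted encoder by name rather than by index position.
fitted_ohe = preprocessor.named_transformers_['cat'].named_steps['onehot']

# Newer releases expose get_feature_names_out(); older ones only have
# get_feature_names(). Use whichever is available.
if hasattr(fitted_ohe, 'get_feature_names_out'):
    names = fitted_ohe.get_feature_names_out()
else:
    names = fitted_ohe.get_feature_names()

print(original_is_fitted)
print(list(names))
```

Looking things up through named_transformers_ and named_steps avoids depending on the position of the encoder in the transformer list. The exact prefix of the generated names may differ between versions.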
Upvotes: 9