Klim Yadrintsev

Reputation: 111

sklearn.preprocessing.OneHotEncoder and the way to read it

I have been using one-hot encoding for a while now in all of my data pre-processing pipelines.

But I have run into an issue now that I am trying to pre-process new data automatically with a Flask server running a model.

TL;DR: I am trying to search new data for a specific date, region and type, and then run .predict on the result.

The problem arises because, after I find a specific data point, I have to convert its columns from objects to the one-hot encoded ones.

My question is: how do I know which column corresponds to which category within a feature? I have around 240 columns after one-hot encoding.

Upvotes: 1

Views: 6769

Answers (2)

Sole Galli

Reputation: 1072

An alternative to sklearn's OneHotEncoder is Feature-engine's OneHotEncoder, which returns clearly named dummy variables:

import pandas as pd
from feature_engine.encoding import OneHotEncoder
X = pd.DataFrame(dict(x1 = [1,2,3,4], x2 = ["a", "a", "b", "c"]))
ohe = OneHotEncoder()
ohe.fit(X)
ohe.transform(X)

The previous code returns the following dataframe:

   x1  x2_a  x2_b  x2_c
0   1     1     0     0
1   2     1     0     0
2   3     0     1     0
3   4     0     0     1

The encoded variables are named with the variable name, then underscore, and then the category, so they are very easy to identify.

See the documentation for Feature-engine's OneHotEncoder for more details.

Upvotes: 1

Corralien

Reputation: 120409

If I understand correctly, use get_feature_names_out():

import pandas as pd
from sklearn.preprocessing import OneHotEncoder

df = pd.DataFrame({'A': [0, 1, 2], 'B': [3, 1, 0],
                   'C': [0, 2, 2], 'D': [0, 1, 1]})

ohe = OneHotEncoder()
data = ohe.fit_transform(df)
df1 = pd.DataFrame(data.toarray(), columns=ohe.get_feature_names_out(), dtype=int)

Output:

>>> df
   A  B  C  D
0  0  3  0  0
1  1  1  2  1
2  2  0  2  1


>>> df1
   A_0  A_1  A_2  B_0  B_1  B_3  C_0  C_2  D_0  D_1
0    1    0    0    0    0    1    1    0    1    0
1    0    1    0    0    1    0    0    1    0    1
2    0    0    1    1    0    0    0    1    0    1

>>> pd.Series(ohe.get_feature_names_out()).str.rsplit('_', n=1).str[0]
0    A
1    A
2    A
3    B
4    B
5    B
6    C
7    C
8    D
9    D
dtype: object

Upvotes: 2
