Reputation: 19037
Categorical columns are a great way to save RAM in pandas, but there are times when they just slow things down, especially once you are past the stage of having a big dataframe and are now working on a subset. For example, they don't seem to play well with printing in Jupyter or with libraries like qgrid.
Basically, I would like to convert all categorical columns of a dataframe back to their original dtypes to speed up simple things.
Here is an example:
import pandas as pd
df = pd.DataFrame({"A": ["a", "b", "c", "a"],
                   "B": ["a", "b", "c", "a"],
                   "C": [0, 3, 0, 3],
                   "D": [0.2, 0.2, 0.3, 0.3],
                   "F": [0, 1, 2, 3]})
df["B"] = df["B"].astype('category')
df["C"] = df["C"].astype('category')
df["D"] = df["D"].astype('category')
This makes columns B, C and D categorical, each with categories of a different type (str, int and float):
df.dtypes
A object
B category
C category
D category
F int64
dtype: object
Ideally something like:
df = df.remove_all_categorical_columns();
that would restore the original dtypes:
df.dtypes
A object
B object
C int64
D float64
F int64
dtype: object
Upvotes: 3
Views: 6970
Reputation: 19037
You can recover the original data type using df['column'].cat.categories.dtype. The rest is a matter of going through all the columns and applying df['column'].astype(df['column'].cat.categories.dtype) to each categorical one.
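As a quick check of that idea, here is a minimal sketch on the question's data that lists the dtype stored in each categorical column's categories. It assumes pandas is installed and uses select_dtypes(include="category") to pick out the categorical columns; category_dtypes is just an illustrative name.

```python
import pandas as pd

# Same data as in the question
df = pd.DataFrame({"A": ["a", "b", "c", "a"],
                   "B": ["a", "b", "c", "a"],
                   "C": [0, 3, 0, 3],
                   "D": [0.2, 0.2, 0.3, 0.3],
                   "F": [0, 1, 2, 3]})
df = df.astype({"B": "category", "C": "category", "D": "category"})

# The categories Index keeps the dtype the values had before conversion
category_dtypes = {col: df[col].cat.categories.dtype
                   for col in df.select_dtypes(include="category").columns}
for col, dtype in category_dtypes.items():
    print(col, dtype)
```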
The following would work in your example (and hopefully generic enough for other cases):
def uncategorize(col):
    if col.dtype.name == 'category':
        try:
            return col.astype(col.cat.categories.dtype)
        except (ValueError, TypeError):
            # Missing values cannot be stored in int64, so fall back to the
            # nullable extension dtype with the capitalized name
            # (e.g. 'int64' -> 'Int64', which can hold pd.NA)
            return col.astype(col.cat.categories.dtype.name.title())
    else:
        return col
df = df.apply(uncategorize, axis=0)
Then, you recover your original dtypes.
df.dtypes
A object
B object
C int64
D float64
F int64
dtype: object
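The except branch matters when a column has missing values: int64 cannot hold NaN, so the plain cast raises and the nullable Int64 dtype is used instead. A minimal sketch of that case, assuming pandas >= 1.0 and a made-up column s:

```python
import pandas as pd

# Hypothetical integer-valued categorical column with one missing entry
s = pd.Series(pd.Categorical([0, None, 3], categories=[0, 3]))

try:
    s.astype(s.cat.categories.dtype)  # the categories are int64
    plain_cast_failed = False
except ValueError:
    # int64 has no way to represent the missing value
    plain_cast_failed = True

# 'int64'.title() gives 'Int64', pandas' nullable integer dtype
restored = s.astype(s.cat.categories.dtype.name.title())
print(plain_cast_failed, restored.dtype)
```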
Upvotes: 1
Reputation: 2187
Similar to toto's answer, but without df.apply():
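def recover_dtypes(df):
    for col in df.columns:
        if df[col].dtype == 'category':
            # The categories Index keeps the dtype the column had before conversion
            df[col] = df[col].astype(df[col].cat.categories.to_numpy().dtype)
    # Note: this modifies the columns of the DataFrame that was passed in
    return df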
df1 = recover_dtypes(df)
print(df1.dtypes)
>>>
A object
B object
C int64
D float64
F int64
dtype: object
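Because recover_dtypes writes into the DataFrame it receives, df1 and df end up being the same object. If you would rather leave the input untouched, one possible variant (recover_dtypes_copy is a hypothetical name) builds a column-to-dtype mapping and passes it to DataFrame.astype, which returns a new DataFrame. Like recover_dtypes, this sketch does not handle integer categories with missing values.

```python
import pandas as pd

def recover_dtypes_copy(df):
    # Map each categorical column to the dtype of its categories
    mapping = {col: df[col].cat.categories.dtype
               for col in df.columns
               if isinstance(df[col].dtype, pd.CategoricalDtype)}
    # astype with a dict converts only the listed columns and returns a new DataFrame
    return df.astype(mapping)

df = pd.DataFrame({"A": ["a", "b", "c", "a"],
                   "B": ["a", "b", "c", "a"],
                   "C": [0, 3, 0, 3],
                   "D": [0.2, 0.2, 0.3, 0.3],
                   "F": [0, 1, 2, 3]})
df = df.astype({"B": "category", "C": "category", "D": "category"})
df1 = recover_dtypes_copy(df)
print(df1.dtypes)
```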
Upvotes: 3