Myccha


Cast a list column into dummy columns in Python Polars?

I have a very large data frame with a column containing a comma-separated list of numbers that represent category membership.

Here is a dummy version

import pandas as pd 
import numpy as np 

segments = [str(i) for i in range(1_000)]

# My real data is ~500m rows
nums = np.random.choice(segments, (100_000,10))

df = pd.DataFrame({'segments': [','.join(n) for n in nums]})
userId segments
0 885,106,49,138,295,254,26,460,0,844
1 908,709,454,966,151,922,666,886,65,708
2 664,713,272,241,301,498,630,834,702,289
3 60,880,906,471,437,383,878,369,556,876
4 817,183,365,171,23,484,934,476,273,230
... ...

Note that there is a known list of segments (0-999 in the example).

I want to cast this into dummy columns indicating membership to each segment.
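For example, for two rows whose segments are "1,3" and "2", the desired output looks roughly like this (with the full 0-999 set of columns in the real case):

row_index    1    2    3
0            1    0    1
1            0    1    0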

I found a few ways of doing this:

In pandas:

df_one_hot_encoded = (df['segments']
    .str.split(',')        # list of segment strings per row
    .explode()             # one row per (original index, segment)
    .reset_index()         # keep the original row id as the 'index' column
    .assign(__one__=1)     # marker value to pivot on
    .pivot_table(index='index', columns='segments', values='__one__', fill_value=0)
)

(takes 8 seconds on a 100k row sample)

And in Polars:

import polars as pl

df2 = pl.from_pandas(df)

df_ans = (df2
          .with_columns(
              pl.int_range(pl.len()).alias('row_index'), 
              pl.col('segments').str.split(','), 
              pl.lit(1).alias('__one__')
          )
          .explode('segments')
          .pivot(on='segments', index='row_index', values='__one__', aggregate_function='first')
          .fill_null(0)
     )
df_one_hot_encoded = df_ans.to_pandas()

(takes 1.5 seconds inclusive of the conversion to and from pandas, 0.9s without)

However, I hear that .pivot is not efficient, and that it is not available on lazy frames.
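As far as I can tell, .pivot is only defined on eager DataFrames, so a lazy pipeline has to collect before pivoting, something like this (a sketch using the same row_index/__one__ columns as above):

# Sketch: .pivot is not part of the lazy API, so collect() first.
df_wide = (
    df2.lazy()
    .with_columns(
        pl.int_range(pl.len()).alias('row_index'),
        pl.col('segments').str.split(','),
        pl.lit(1).alias('__one__'),
    )
    .explode('segments')
    .collect()  # materialize the frame; LazyFrame has no .pivot
    .pivot(on='segments', index='row_index', values='__one__', aggregate_function='first')
    .fill_null(0)
)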

I tried other solutions in polars, but they were much slower:

_ = df2.lazy().with_columns(**{segment: pl.col('segments').str.contains(segment) for segment in segments}).collect()

(2 seconds)

(df2
 .with_columns(
     pl.int_range(pl.len()).alias('row_index'), 
     pl.col('segments').str.split(',')
 )
 .explode('segments')
 .to_dummies(columns=['segments'])
 .group_by('row_index')
 .sum()
)

(4 seconds)

Does anyone know a better solution than the 0.9s pivot?

Upvotes: 3

Views: 849

Answers (1)

Dean MacGregor


This approach ends up being slower than the pivot, but it uses a different trick so I'll include it.

df2 = pl.from_pandas(df)

df2_ans = (
    df2
    .with_row_index('userId')
    .with_columns(pl.col('segments').str.split(','))
    .explode('segments')
    .with_columns(
        pl.when(pl.col('segments') == pl.lit(str(i)))
          .then(pl.lit(1, pl.Int32))
          .otherwise(pl.lit(0, pl.Int32))
          .alias(str(i))
        for i in range(1000)
    )
    .group_by('userId')
    .agg(pl.exclude('segments').sum())
)
df_one_hot_encoded = df2_ans.to_pandas()

A couple of other observations. I'm not sure if you checked the output of your str.contains approach, but I don't think it works: for example, the string 15 is contained within 154, so you get false positives.
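A quick way to see that, and one possible workaround using an anchored regex rather than a plain substring match (just a sketch; the patterns are built per segment and assume the segments stay comma-separated):

# '15' is a substring of '154', so a plain contains gives a false positive:
pl.DataFrame({'segments': ['154,7']}).select(pl.col('segments').str.contains('15'))  # -> true

# Anchoring each segment between commas (or string boundaries) avoids that:
_ = df2.lazy().with_columns(
    **{s: pl.col('segments').str.contains(f'(^|,){s}(,|$)') for s in segments}
).collect()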

The other thing, which I guess is just a preference, is the with_row_index syntax vs pl.int_range. I don't think either performs meaningfully better.
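For reference, the two spellings give the same 0..n-1 index (a minimal comparison; if I remember correctly, with_row_index puts the column first and defaults to an unsigned integer dtype, while pl.int_range defaults to Int64):

a = df2.with_row_index('row_index')                              # index as first column
b = df2.with_columns(pl.int_range(pl.len()).alias('row_index'))  # index appended at the end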

I tried a couple of other things that were also worse: skipping the explode and just doing an is_in membership test, and using bools instead of 1s and 0s and then aggregating with any. Both were slower.
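For reference, the no-explode membership idea was roughly along these lines (a sketch written with .list.contains, which expresses the same per-row membership check as is_in):

# Sketch: test membership per segment directly on the list column, no explode.
_ = (
    df2
    .with_columns(pl.col('segments').str.split(','))
    .with_columns(
        pl.col('segments').list.contains(str(i)).cast(pl.Int32).alias(str(i))
        for i in range(1000)
    )
    .drop('segments')
)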

Upvotes: 1
