Giora Simchoni

GroupBy and aggregate with set intersection

I have a pandas DataFrame with a sets column:

import pandas as pd

df = pd.DataFrame({'group_var': [1,1,2,2], 'sets_var': [set([0, 1]), set([1, 2]), set([3, 4]), set([5, 6, 7])]})
df

   group_var sets_var
0          1      {0, 1}
1          1      {1, 2}
2          2      {3, 4}
3          2   {5, 6, 7}

I want to group by group_var and take the intersection of all the sets in sets_var within each group, like so:

   group_var sets_var
0          1      {1}
1          2      {}

or a Series like so:

   sets_var
1  {1}
2  {}

How can I do this elegantly? Performance is my top priority.


Answers (1)

cs95

Use groupby and agg, reducing each group's sets with set.intersection:

df.groupby('group_var', as_index=False).agg(lambda x: set.intersection(*x))

   group_var sets_var
0          1      {1}
1          2       {}
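To make it concrete what agg does here, the sketch below repeats the same per-group reduction in plain Python, without pandas. It collects each group's sets in a dict and calls set.intersection(*sets) once per group. The names rows, groups and result are illustrative and simply mirror the question's data; this is a sketch of the logic, not a benchmarked alternative.

```python
from collections import defaultdict

# Same data as the question's DataFrame, as plain (group key, set) pairs.
rows = [(1, {0, 1}), (1, {1, 2}), (2, {3, 4}), (2, {5, 6, 7})]

# Collect the sets that belong to each group key, in order of appearance.
groups = defaultdict(list)
for key, s in rows:
    groups[key].append(s)

# One variadic set.intersection call per group, as in the agg lambda above.
result = {key: set.intersection(*sets) for key, sets in groups.items()}
print(result)  # {1: {1}, 2: set()}
```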

If performance is absolutely important, we can try getting rid of the lambda:

from functools import partial, reduce
import operator

# p(group) is reduce(operator.and_, group), i.e. s1 & s2 & ... & sn
p = partial(reduce, operator.and_)
df.groupby('group_var', as_index=False).agg(p)

   group_var sets_var
0          1      {1}
1          2       {}

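However, reduce intersects the sets two at a time through repeated Python-level calls to operator.and_, so it is not guaranteed to beat the single set.intersection(*x) call. Benchmark both on your own data.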
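A rough way to check this on your own machine is a small timeit comparison like the one below. The data (50 overlapping ranges of 1,000 integers) and the helper names variadic and pairwise are made up for illustration. The assertion only confirms that both strategies return the same set; the printed timings will depend on your Python version and data.

```python
import operator
import timeit
from functools import reduce

# Hypothetical benchmark data: 50 overlapping sets of 1,000 ints each.
sets = [set(range(i, i + 1000)) for i in range(50)]

def variadic(group):
    # A single method call that receives every set as an argument.
    return set.intersection(*group)

def pairwise(group):
    # reduce calls operator.and_ once per step, left to right, from Python.
    return reduce(operator.and_, group)

# Both strategies produce the same set: the values 49..999.
assert variadic(sets) == pairwise(sets) == set(range(49, 1000))

for fn in (variadic, pairwise):
    seconds = timeit.timeit(lambda: fn(sets), number=100)
    print(f"{fn.__name__}: {seconds:.4f} s for 100 runs")
```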


Or, to get the result as a Series:

pd.Series({
    k: set.intersection(*g.tolist()) 
    for k, g in df.groupby('group_var')['sets_var']})

1    {1}
2     {}
dtype: object
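If you only need the Series, selecting the column before aggregating may be simpler, since SeriesGroupBy.agg returns a Series indexed by group_var directly. This is a sketch that rebuilds the question's DataFrame and applies the same set.intersection(*x) reduction as above; it has not been benchmarked against the dict comprehension.

```python
import pandas as pd

df = pd.DataFrame({
    'group_var': [1, 1, 2, 2],
    'sets_var': [{0, 1}, {1, 2}, {3, 4}, {5, 6, 7}],
})

# Selecting 'sets_var' first makes agg return a Series (named 'sets_var')
# indexed by group_var, so no dict comprehension is needed.
result = df.groupby('group_var')['sets_var'].agg(lambda s: set.intersection(*s))
print(result)
```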
