Python Spark

Reputation: 303

Pandas multiple column intersection

I have a data frame as follows:

import pandas as pd

data = {'NAME': ['JOHN', 'MARY', 'CHARLIE'],
        'A': [[1, 2, 3], [2, 3, 4], [3, 4, 5]],
        'B': [[2, 3, 4], [3, 4, 5], [4, 5, 6]],
        'C': [[2, 4], [3, 4], [6, 7]]}
df = pd.DataFrame(data)
df = df[['NAME', 'A', 'B', 'C']]

      NAME          A          B       C
0     JOHN  [1, 2, 3]  [2, 3, 4]  [2, 4]
1     MARY  [2, 3, 4]  [3, 4, 5]  [3, 4]
2  CHARLIE  [3, 4, 5]  [4, 5, 6]  [6, 7]

For each row, I need the intersection of the lists in columns A, B, and C.

I tried the following code, but it did not work:

df['D']=list(set(df['A'])&set(df['B'])&set(df['C']))

The output required is as follows:

    NAME            A         B         C       D
0   JOHN    [1, 2, 3]   [2, 3, 4]   [2, 4]  [2]
1   MARY    [2, 3, 4]   [3, 4, 5]   [3, 4]  [3, 4]
2   CHARLIE [3, 4, 5]   [4, 5, 6]   [6, 7]  []

Upvotes: 2

Views: 8490

Answers (3)

Mohamed Ali JAMAOUI

Reputation: 14689

option 1:

The intersection syntax set(A) & set(B) & ... is correct, but it has to be applied to the lists of each row rather than to whole columns. Wrap it in a row-wise function as follows:

df.assign(D=df.transform(
     lambda x: list(set(x.A)&set(x.B)&set(x.C)),
     axis=1))
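To illustrate why the column-wise attempt in the question fails while the row-wise version works, here is a minimal sketch. It assumes a plain list of lists (the hypothetical name column_a) as a stand-in for df['A'], so it runs without pandas; iterating over the real column yields the same inner lists.

```python
# A plain list of lists stands in for the column df['A'] (an assumption made
# so the sketch runs without pandas).
column_a = [[1, 2, 3], [2, 3, 4], [3, 4, 5]]

try:
    set(column_a)          # set() must hash every item, and lists are unhashable
    failed = False
except TypeError as exc:
    failed = True
    print(exc)             # unhashable type: 'list'

# Building sets from the lists of ONE row works, because the items are ints.
row_a, row_b, row_c = [1, 2, 3], [2, 3, 4], [2, 4]
common = sorted(set(row_a) & set(row_b) & set(row_c))
print(common)              # [2]
```

This is why the lambda above receives one row at a time (axis=1) and builds the sets from x.A, x.B and x.C.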

option 2:

The same row-wise logic, written with the intersection method instead of the & operator:

df.assign(D=df.transform(
    lambda x: list(set(x.A).intersection(set(x.B)).intersection(set(x.C))),
    axis=1))

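or with apply, the usual choice for a function that reduces each row to a single value (newer pandas versions may refuse such a function in transform with a ValueError):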

df.assign(D=df.apply(
    lambda x: list(set(x.A).intersection(set(x.B)).intersection(set(x.C))),
    axis=1))

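option 3:

from functools import reduce  # reduce is not a builtin in Python 3

df.assign(D=df.transform(
    # x.tolist()[1:] skips NAME and keeps the list columns A, B, C
    lambda x: list(reduce(set.intersection, map(set, x.tolist()[1:]))),
    axis=1))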

What this does:

  • For each row, compute the intersection by chaining set(x.A).intersection(set(x.B)).intersection(set(x.C))
  • Convert the resulting set to a list (the order of its elements is not guaranteed)
  • Repeat this for every row of the dataframe (axis=1)
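The three steps above can be sketched with a named function instead of a lambda. The helper name row_intersection is hypothetical, and the sketch uses apply plus sorted() (an addition, so the list order is predictable) rather than the exact calls shown in the options.

```python
import pandas as pd

df = pd.DataFrame({
    'NAME': ['JOHN', 'MARY', 'CHARLIE'],
    'A': [[1, 2, 3], [2, 3, 4], [3, 4, 5]],
    'B': [[2, 3, 4], [3, 4, 5], [4, 5, 6]],
    'C': [[2, 4], [3, 4], [6, 7]],
})

def row_intersection(row):
    """Hypothetical helper: intersect the lists in A, B and C of one row."""
    # Step 1: chain the intersections (intersection accepts any iterable).
    common = set(row['A']).intersection(row['B']).intersection(row['C'])
    # Step 2: convert the set to a list; sorted() also fixes the order.
    return sorted(common)

# Step 3: run the helper once per row.
df['D'] = df.apply(row_intersection, axis=1)
print(df['D'].tolist())   # [[2], [3, 4], []]
```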

Execution details:

In [76]: df.assign(D=df.transform(
    ...:     lambda x: list(set(x.A).intersection(set(x.B)).intersection(set(x.C))),
    ...:     axis=1))
Out[76]: 
      NAME          A          B       C       D
0     JOHN  [1, 2, 3]  [2, 3, 4]  [2, 4]     [2]
1     MARY  [2, 3, 4]  [3, 4, 5]  [3, 4]  [3, 4]
2  CHARLIE  [3, 4, 5]  [4, 5, 6]  [6, 7]      []

Upvotes: 1

BENY

Reputation: 323306

df[['A', 'B', 'C']].apply(lambda x: list(set.intersection(*map(set, list(x)))), axis=1)

Out[1192]: 
0       [2]
1    [3, 4]
2        []
dtype: object

Upvotes: 2

IanS

Reputation: 16251

Using the answer here, apply it to the dataframe row by row:

df[['A', 'B', 'C']].apply(
    lambda row: list(set.intersection(*[set(row[col]) for col in row.index])), 
    axis=1
)

Note that when applying a function by row, the row's index values are the original dataframe's columns.
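A small sketch of that note, assuming the sample dataframe from the question. The name cols is hypothetical, and sorted() is used in place of list() only to make the output order predictable.

```python
import pandas as pd

df = pd.DataFrame({
    'NAME': ['JOHN', 'MARY', 'CHARLIE'],
    'A': [[1, 2, 3], [2, 3, 4], [3, 4, 5]],
    'B': [[2, 3, 4], [3, 4, 5], [4, 5, 6]],
    'C': [[2, 4], [3, 4], [6, 7]],
})
cols = ['A', 'B', 'C']   # hypothetical name for the list-valued columns

# One row of the selection is a Series whose index holds the column labels,
# the same kind of object apply passes to the function with axis=1.
first_row = df[cols].iloc[0]
print(list(first_row.index))   # ['A', 'B', 'C']

result = df[cols].apply(
    lambda row: sorted(set.intersection(*[set(row[col]) for col in row.index])),
    axis=1,
)
print(result.tolist())         # [[2], [3, 4], []]
```

Because the function only looks at row.index, adding another list-valued column to cols would include it in the intersection without further changes.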

Upvotes: 3
