amal

Reputation: 11

Intersect operation between two DataFrames using SQL in PySpark

I have two DataFrames.

First DataFrame:

+---+---+----------------+
|  p|  o| collect_list(s)|
+---+---+----------------+
|  T| V2|[c1, c5]        |
|  T| V1|[c2, c3, c4, c6]|
+---+---+----------------+

Second DataFrame:

+---+---+--------------------+
|  p|  o|     collect_list(s)|
+---+---+--------------------+
|  A| V3|[c1, c2, c3, c4]    |
|  B| V3|[c1, c2, c3, c5, c6]|
+---+---+--------------------+

How can I do an intersect operation between the DataFrames above based on the collect_list column? The result should be another DataFrame that pairs a row of the first DataFrame with a row of the second whenever the intersection of their lists has at least the minimum support of 2 items, with the shared items in a TID column, as follows:

+---------------+------------+
|2-Itemset      |TID         |
+---------------+------------+
|[(T,V2),(B,V3)]|[c1, c5]    |
|[(T,V1),(A,V3)]|[c2, c3, c4]|
|[(T,V1),(B,V3)]|[c2, c3, c6]|
+---------------+------------+
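To pin down the intended semantics independently of Spark, here is a minimal plain-Python sketch (the names `left`, `right`, `two_itemsets` and `MIN_SUPPORT` are hypothetical, not from the question): every row of the first DataFrame is paired with every row of the second, and a pair is kept when the two lists share at least `MIN_SUPPORT` items. The shared items are sorted so the result order is deterministic.

```python
from itertools import product

# Rows of the two DataFrames from the question, as (p, o, list) tuples
left = [("T", "V2", ["c1", "c5"]), ("T", "V1", ["c2", "c3", "c4", "c6"])]
right = [("A", "V3", ["c1", "c2", "c3", "c4"]), ("B", "V3", ["c1", "c2", "c3", "c5", "c6"])]

MIN_SUPPORT = 2  # hypothetical name: a pair is kept only if the lists share at least this many items


def two_itemsets(left_rows, right_rows, min_support):
    """Return ((left_item, right_item), shared_items) for every qualifying pair."""
    result = []
    # Cartesian product: every left row paired with every right row
    for (lp, lo, lc), (rp, ro, rc) in product(left_rows, right_rows):
        shared = sorted(set(lc) & set(rc))  # sorted for a deterministic order
        if len(shared) >= min_support:
            result.append((((lp, lo), (rp, ro)), shared))
    return result


for itemset, tid in two_itemsets(left, right, MIN_SUPPORT):
    print(itemset, tid)
```

The pair (T,V2)/(A,V3) is dropped because its lists share only `c1`, which is below the minimum support.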

Upvotes: 0

Views: 377

Answers (1)

pault

Reputation: 43504

I think you need to do a crossJoin() and use some UDFs to accomplish this. Consider the following example:

Create example data

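import pyspark.sql.functions as f
from pyspark.sql import SparkSession

# Entry point for the DataFrame API (newer replacement for SQLContext)
spark = SparkSession.builder.getOrCreate()

# Column "c" stands in for the collect_list(s) column from the question
data1 = [
    ('T', 'V2', ['c1', 'c5']),
    ('T', 'V1', ['c2', 'c3', 'c4', 'c6'])
]
df1 = spark.createDataFrame(data1, ["p", "o", "c"])

data2 = [
    ('A', 'V3', ['c1', 'c2', 'c3', 'c4']),
    ('B', 'V3', ['c1', 'c2', 'c3', 'c5', 'c6'])
]
df2 = spark.createDataFrame(data2, ["p", "o", "c"])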

Define some UDFs

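from pyspark.sql.types import ArrayType, IntegerType, StringType

# Items the two lists have in common (order is not guaranteed, since it goes through a set)
intersection_udf = f.udf(lambda u, v: list(set(u) & set(v)), ArrayType(StringType()))
# Number of distinct items the two lists have in common
intersection_length_udf = f.udf(lambda u, v: len(set(u) & set(v)), IntegerType())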
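The UDF bodies are ordinary Python functions, so their logic can be checked without Spark. Python does not guarantee any particular iteration order for a `set`, which is likely why `TID` comes back as `[c3, c2, c4]` rather than `[c2, c3, c4]`. A minimal sketch of the same logic with the intersection sorted for a stable order; the function names `intersection` and `intersection_length` are hypothetical, and they could be wrapped as, e.g., `f.udf(intersection, ArrayType(StringType()))`.

```python
def intersection(u, v):
    """Items common to both lists, sorted so the order does not depend on set iteration."""
    return sorted(set(u) & set(v))


def intersection_length(u, v):
    """Number of distinct items the two lists have in common."""
    return len(set(u) & set(v))


print(intersection(["c2", "c3", "c4", "c6"], ["c1", "c2", "c3", "c5", "c6"]))
print(intersection_length(["c1", "c5"], ["c1", "c2", "c3", "c4"]))
```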

Cartesian Product

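# Pair every row of df1 with every row of df2, then keep only the pairs
# whose lists share at least 2 items (the minimum support)
result = (
    df1.alias("l")
    .crossJoin(df2.alias("r"))
    .select(
        f.col('l.p').alias('lp'),
        f.col('l.o').alias('lo'),
        f.col('r.p').alias('rp'),
        f.col('r.o').alias('ro'),
        intersection_udf(f.col('l.c'), f.col('r.c')).alias('TID'),
        intersection_length_udf(f.col('l.c'), f.col('r.c')).alias('len')
    )
    .where(f.col('len') > 1)  # len > 1, i.e. len >= 2
    .select(
        # Nest the (p, o) pairs of both sides into a single 2-itemset column
        f.struct(f.struct('lp', 'lo'), f.struct('rp', 'ro')).alias('2-Itemset'),
        'TID'
    )
)
result.show()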

Output

#+---------------+------------+
#|      2-Itemset|         TID|
#+---------------+------------+
#|[[T,V2],[B,V3]]|    [c1, c5]|
#|[[T,V1],[A,V3]]|[c3, c2, c4]|
#|[[T,V1],[B,V3]]|[c3, c2, c6]|
#+---------------+------------+

Upvotes: 1
