Rakesh Adhikesavan

Reputation: 12826

PySpark: Compare array values in one DataFrame with array values in another DataFrame to get the intersection

I have the following two DataFrames:

l1 = [(['hello','world'],), (['stack','overflow'],), (['hello', 'alice'],), (['sample', 'text'],)]
df1 = spark.createDataFrame(l1)

l2 = [(['big','world'],), (['sample','overflow', 'alice', 'text', 'bob'],), (['hello', 'sample'],)]
df2 = spark.createDataFrame(l2) 
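(Each DataFrame ends up with a single array column, named _1 by default; a quick check, exact output may vary slightly by Spark version:)

df1.printSchema()
# root
#  |-- _1: array (nullable = true)
#  |    |-- element: string (containsNull = true)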

df1:

["hello","world"]
["stack","overflow"]
["hello","alice"]
["sample","text"]

df2:

["big","world"]
["sample","overflow","alice","text","bob"]
["hello", "sample"]

For every row in df1, I want to intersect its array with every row of df2 and sum the lengths of those intersections.

For example, the first row in df1 is ["hello","world"]. Intersecting ["hello","world"] with every row in df2 gives:

| ARRAY                                      | INTERSECTION | LEN(INTERSECTION) |
| ["big","world"]                            | ["world"]    | 1                 |
| ["sample","overflow","alice","text","bob"] | []           | 0                 |
| ["hello","sample"]                         | ["hello"]    | 1                 |
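For illustration, on Spark 2.4+ this per-row table can be computed directly with a brute-force cross join using array_intersect and size (a sketch of the desired semantics, not an efficient solution):

from pyspark.sql.functions import array_intersect, size

# Pair every row of df1 with every row of df2, then compute each
# pair's intersection and its length.
pairs = (df1.toDF("arr1")
    .crossJoin(df2.toDF("arr2"))
    .withColumn("intersection", array_intersect("arr1", "arr2"))
    .withColumn("len", size("intersection")))
pairs.show(truncate=False)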

Now, I want to return the sum(len(intersection)). Ultimately, I want the resulting df1 to look like this:

df1 result:

| ARRAY                | INTERSECTION_TOTAL |
| ["hello","world"]    | 2                  |
| ["stack","overflow"] | 1                  |
| ["hello","alice"]    | 2                  |
| ["sample","text"]    | 3                  |

How do I solve this?

Upvotes: 4

Views: 2753

Answers (1)

Alper t. Turker

Reputation: 35229

I'd focus on avoiding a Cartesian product first. I'd try to explode and join:

from pyspark.sql.functions import explode, monotonically_increasing_id

# Name the array column, tag each row with a unique id, and
# explode the array into one word per output row.
df1_ = (df1.toDF("words")
    .withColumn("id_1", monotonically_increasing_id())
    .select("*", explode("words").alias("word")))

df2_ = (df2.toDF("words")
    .withColumn("id_2", monotonically_increasing_id())
    .select("id_2", explode("words").alias("word")))

# Join on the exploded word and count matches per (df1 row, df2 row)
# pair, then sum those per-pair counts for each df1 row.
(df1_.join(df2_, "word").groupBy("id_1", "id_2", "words").count()
    .groupBy("id_1", "words").sum("count").drop("id_1").show())
+-----------------+----------+                                                  
|            words|sum(count)|
+-----------------+----------+
|   [hello, alice]|         2|
|   [sample, text]|         3|
|[stack, overflow]|         1|
|   [hello, world]|         2|
+-----------------+----------+

If intermediate values are not needed, it could be simplified to:

df1_.join(df2_, "word").groupBy("words").count().show()
+-----------------+-----+                                                       
|            words|count|
+-----------------+-----+
|   [hello, alice]|    2|
|   [sample, text]|    3|
|[stack, overflow]|    1|
|   [hello, world]|    2|
+-----------------+-----+

and you could omit adding ids.
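A minimal sketch of that id-free variant (same names as above; note it would merge duplicate arrays in df1, which the ids otherwise keep distinct):

from pyspark.sql.functions import explode

# Explode both frames without row ids and count word matches per array.
df1_ = df1.toDF("words").select("words", explode("words").alias("word"))
df2_ = df2.toDF("words").select(explode("words").alias("word"))

df1_.join(df2_, "word").groupBy("words").count().show()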

Upvotes: 1
