I have the following two DataFrames:
from pyspark.sql import SparkSession
spark = SparkSession.builder.getOrCreate()
l1 = [(['hello','world'],), (['stack','overflow'],), (['hello', 'alice'],), (['sample', 'text'],)]
df1 = spark.createDataFrame(l1)
l2 = [(['big','world'],), (['sample','overflow', 'alice', 'text', 'bob'],), (['hello', 'sample'],)]
df2 = spark.createDataFrame(l2)
df1:
["hello","world"]
["stack","overflow"]
["hello","alice"]
["sample","text"]
df2:
["big","world"]
["sample","overflow","alice","text","bob"]
["hello", "sample"]
For every row in df1, I want to calculate how many times the words in its array occur in df2, summed over all rows of df2.
For example, the first row in df1 is ["hello","world"]. I want to intersect ["hello","world"] with every row in df2:
| ARRAY | INTERSECTION | LEN(INTERSECTION) |
|---|---|---|
| ["big","world"] | ["world"] | 1 |
| ["sample","overflow","alice","text","bob"] | [] | 0 |
| ["hello","sample"] | ["hello"] | 1 |
Then I want to return sum(len(intersection)), which is 1 + 0 + 1 = 2 for this row. Ultimately, I want the resulting df1 to look like this:
df1 result:
| ARRAY | INTERSECTION_TOTAL |
|---|---|
| ["hello","world"] | 2 |
| ["stack","overflow"] | 1 |
| ["hello","alice"] | 2 |
| ["sample","text"] | 3 |
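As a reference for the expected totals, here is a small plain-Python sketch (not Spark code) of the definition above: intersect each df1 array with every df2 array and add up the intersection sizes. The lists `df1_rows` and `df2_rows` are hypothetical stand-ins for the DataFrame contents. This compares every pair of rows, so it is only meant for checking results on toy data.

```python
# Plain-Python reference for the question's definition (no Spark involved).
# df1_rows / df2_rows are hypothetical stand-ins for the two DataFrames.
df1_rows = [["hello", "world"], ["stack", "overflow"],
            ["hello", "alice"], ["sample", "text"]]
df2_rows = [["big", "world"],
            ["sample", "overflow", "alice", "text", "bob"],
            ["hello", "sample"]]


def intersection_total(words, others):
    """Sum of len(intersection) between `words` and each array in `others`.

    Uses sets, so a word repeated inside one array counts only once.
    """
    wanted = set(words)
    return sum(len(wanted & set(other)) for other in others)


# Compares every df1 row with every df2 row (a Cartesian product).
totals = [(words, intersection_total(words, df2_rows)) for words in df1_rows]
for words, total in totals:
    print(words, total)
```

The printed totals are 2, 1, 2 and 3, matching the table above.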
How do I solve this?
I'd focus on avoiding a Cartesian product first. Instead of comparing every df1 row with every df2 row, I'd explode both arrays into one word per row and join on the word:
from pyspark.sql.functions import explode, monotonically_increasing_id

# One row per (df1 row id, word); keep the original array for the output
df1_ = (df1.toDF("words")
        .withColumn("id_1", monotonically_increasing_id())
        .select("*", explode("words").alias("word")))

# One row per (df2 row id, word)
df2_ = (df2.toDF("words")
        .withColumn("id_2", monotonically_increasing_id())
        .select("id_2", explode("words").alias("word")))

# Count matching words per (df1 row, df2 row) pair, then sum per df1 row
(df1_.join(df2_, "word").groupBy("id_1", "id_2", "words").count()
    .groupBy("id_1", "words").sum("count").drop("id_1").show())
+-----------------+----------+
|            words|sum(count)|
+-----------------+----------+
|   [hello, alice]|         2|
|   [sample, text]|         3|
|[stack, overflow]|         1|
|   [hello, world]|         2|
+-----------------+----------+
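To see what the explode/join pipeline computes, here is a plain-Python model of it (not Spark code): each array is exploded into `(id, word)` pairs, the pairs are joined on `word`, and matches are counted per `(id_1, id_2)` and then summed per `id_1`. The row data and variable names are illustrative assumptions, with the list index standing in for the generated ids. Only pairs of rows that share at least one word ever meet in the join, which is why this avoids comparing every df1 row with every df2 row; like the inner join, a df1 row sharing no words with df2 would produce no entries at all.

```python
from collections import Counter, defaultdict

# Hypothetical stand-ins for the DataFrame rows (list index plays the role of id_1 / id_2)
df1_rows = [["hello", "world"], ["stack", "overflow"],
            ["hello", "alice"], ["sample", "text"]]
df2_rows = [["big", "world"],
            ["sample", "overflow", "alice", "text", "bob"],
            ["hello", "sample"]]

# "explode": one (row id, word) pair per array element
df1_exploded = [(i, w) for i, words in enumerate(df1_rows) for w in words]
df2_exploded = [(j, w) for j, words in enumerate(df2_rows) for w in words]

# Group df2 pairs by word so the join below is a lookup rather than a full scan
df2_ids_by_word = defaultdict(list)
for j, w in df2_exploded:
    df2_ids_by_word[w].append(j)

# join on "word", then groupBy("id_1", "id_2").count()
pair_counts = Counter()
for i, w in df1_exploded:
    for j in df2_ids_by_word.get(w, []):
        pair_counts[(i, j)] += 1

# groupBy("id_1").sum("count")
totals = Counter()
for (i, _j), count in pair_counts.items():
    totals[i] += count

for i, words in enumerate(df1_rows):
    print(words, totals[i])
```

For example, `pair_counts[(3, 1)]` is 2 because ["sample","text"] shares both words with the second df2 row.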
If the intermediate values are not needed, it could be simplified to:
df1_.join(df2_, "word").groupBy("words").count().show()
+-----------------+-----+
|            words|count|
+-----------------+-----+
|   [hello, alice]|    2|
|   [sample, text]|    3|
|[stack, overflow]|    1|
|   [hello, world]|    2|
+-----------------+-----+
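and you could omit adding the ids, as long as the arrays in df1 are distinct (identical arrays would otherwise be grouped together).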
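A variation that is not part of the answer above, sketched in plain Python on the same toy data: count once how many df2 rows contain each word, then sum those counts over each df1 row's distinct words. Because a missing word contributes 0, df1 rows that share no words with df2 are kept, whereas the inner join drops them; the extra row ["no","match"] is hypothetical and only there to show this. Deduplicating each array with `set` also matches the intersection semantics of the question, while the join counts a repeated word once per occurrence. In Spark, the analogous steps would presumably be grouping `df2_` by `word` before the join and using a left join with nulls treated as 0; that translation is an assumption, not tested here.

```python
from collections import Counter

# Hypothetical data; ["no", "match"] is an extra row with no overlap in df2
df1_rows = [["hello", "world"], ["stack", "overflow"],
            ["hello", "alice"], ["sample", "text"], ["no", "match"]]
df2_rows = [["big", "world"],
            ["sample", "overflow", "alice", "text", "bob"],
            ["hello", "sample"]]

# Number of df2 rows containing each word (set() counts a word once per row)
rows_containing = Counter(w for words in df2_rows for w in set(words))

# Sum over each df1 row's distinct words; Counter returns 0 for unseen words
totals = [(words, sum(rows_containing[w] for w in set(words)))
          for words in df1_rows]
for words, total in totals:
    print(words, total)
```

This works because summing len(A ∩ B) over all df2 rows B equals summing, over each word in A, the number of df2 rows that contain it.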