Reputation: 350
I have a dataframe containing maybe ~200k words and phrases. Many of the phrases are duplicative (e.g., the word "adventure" appears thousands of times). I'd like to get a count of each word, and then dedupe. My attempts at doing so take a really long time--longer than anything else in the entire script--and doing anything with the resulting dataframe after I've gotten my counts also takes forever.
My initial input looks like this:
Keyword_Type     Keyword        Clean_Keyword
Not Geography    cat            cat
Not Geography    cat            cat
Not Geography    cat            cat
Not Geography    cats           cat
Not Geography    cave           cave
Not Geography    celebrity      celebrity
Not Geography    celebrities    celebrity
I'm looking for an output like this, which counts the number of times each Clean_Keyword appears anywhere in the dataframe.
Keyword_Type     Keyword        Clean_Keyword    Frequency
Not Geography    cat            cat              4
Not Geography    cat            cat              4
Not Geography    cat            cat              4
Not Geography    cats           cat              4
Not Geography    cave           cave             1
Not Geography    celebrity      celebrity        2
Not Geography    celebrities    celebrity        2
The code snippet I'm currently using is as follows:
from pyspark.sql import Window
from pyspark.sql.functions import count

w = Window.partitionBy("Clean_Keyword")
key_count = key_lemma.select("Keyword_Type", "Keyword", "Clean_Keyword",
                             count("Clean_Keyword").over(w).alias("Frequency")).dropDuplicates()
I also tried a groupBy-count on the Clean_Keyword column and joining back to the original key_count df, but this also takes a disgustingly large amount of time. The functions I run on this same dataset before counting and deduping finish in a second or two or less, and any subsequent functions on the resulting df run faster in straight Python on my computer than in PySpark on a decent-sized cluster. So, needless to say, something I'm doing is definitely not optimal...
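For reference, the groupBy-and-join attempt was roughly along these lines (a sketch of the approach, not my exact code; key_lemma is the input dataframe shown above):

from pyspark.sql.functions import count

# Count each Clean_Keyword once, then join the counts back onto the
# original dataframe and drop duplicate rows.
counts = key_lemma.groupBy("Clean_Keyword").agg(count("*").alias("Frequency"))
key_count = key_lemma.join(counts, on="Clean_Keyword", how="left").dropDuplicates()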
Upvotes: 0
Views: 1758
Reputation: 10076
Your window function solution is going to be the most efficient: a join implies two sorts, whereas a window only implies one. You may be able to optimize by doing a groupBy before windowing instead of a dropDuplicates after:
import pyspark.sql.functions as psf
from pyspark.sql import Window
w = Window.partitionBy("Clean_Keyword")
key_count = key_lemma\
    .groupBy(key_lemma.columns)\
    .agg(psf.count("*").alias("Frequency"))\
    .withColumn("Frequency", psf.sum("Frequency").over(w))
key_count.show()
+-------------+-----------+-------------+---------+
| Keyword_Type| Keyword|Clean_Keyword|Frequency|
+-------------+-----------+-------------+---------+
|Not Geography| cave| cave| 1|
|Not Geography| cat| cat| 4|
|Not Geography| cats| cat| 4|
|Not Geography| celebrity| celebrity| 2|
|Not Geography|celebrities| celebrity| 2|
+-------------+-----------+-------------+---------+
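If you want to see where the sorts and shuffles come from, you can print the physical plans of both versions with explain() and compare the Exchange and Sort steps yourself (a quick sanity check, not a benchmark):

# Physical plan of the window-based version above
key_count.explain()

# Physical plan of a groupBy + join version, for comparison
counts = key_lemma.groupBy("Clean_Keyword").agg(psf.count("*").alias("Frequency"))
key_lemma.join(counts, "Clean_Keyword").dropDuplicates().explain()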
This will be more efficient, especially if you have a lot of lines but not so many distinct keys (most Keywords are equal to their Clean_Keyword).
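A quick way to check whether that holds for your data is to compare the total row count with the distinct row count; the bigger the gap, the more the pre-aggregation shrinks the data before the window's shuffle and sort:

# If the total is much larger than the distinct count, the groupBy
# collapses most rows before the window has to shuffle and sort them.
print(key_lemma.count(), key_lemma.dropDuplicates().count())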
Upvotes: 2