Mithun Manohar

Reputation: 586

PySpark: using a window function

I have a dataframe in which each row represents an instance of a rating for a particular movie by a user. Each movie can be rated in multiple categories by multiple users. This is the resulting dataframe, which I created using the movie_lens data.

+--------+----+----------+
|movie_id|year|categories|
+--------+----+----------+
|     122|1990|    Comedy|
|     122|1990|   Romance|
|     185|1990|    Action|
|     185|1990|     Crime|
|     185|1990|  Thriller|
|     231|1990|    Comedy|
|     292|1990|    Action|
|     292|1990|     Drama|
|     292|1990|    Sci-Fi|
|     292|1990|  Thriller|
|     316|1990|    Action|
|     316|1990| Adventure|
|     316|1990|    Sci-Fi|
|     329|1990|    Action|
|     329|1990| Adventure|
|     329|1990|     Drama|
.
.
.

movie_id is the unique id of the movie, year is the year in which a user rated the movie, and categories is one of the 12 categories of the movie. Partial file here

I want to find the most rated movies in each decade in each category (by counting the frequency of each movie in each decade in each category),

something like

+------+----------+----------+------+
| year | category | movie_id | rank |
+------+----------+----------+------+
| 1990 | Comedy   | 1273     | 1    |
| 1990 | Comedy   | 6547     | 2    |
| 1990 | Comedy   | 8973     | 3    |
.
.
| 1990 | Comedy   | 7483     | 10   |
.
.
| 1990 | Drama    | 1273     | 1    |
| 1990 | Drama    | 6547     | 2    |
| 1990 | Drama    | 8973     | 3    |
.
.
| 1990 | Drama    | 7483     | 10   |
.
.
| 2000 | Comedy   | 1273     | 1    |
| 2000 | Comedy   | 6547     | 2    |
.
.

i.e. for every decade, the top 10 movies in each category.

I understand that a pyspark window function needs to be used. This is what I tried:

from pyspark.sql import Window
import pyspark.sql.functions as func

windowSpec = Window.partitionBy(res_agg['year']).orderBy(res_agg['categories'].desc())

final = res_agg.select(res_agg['year'], res_agg['movie_id'], res_agg['categories']).withColumn('rank', func.rank().over(windowSpec))

but it returns something like the output below:

+----+--------+------------------+----+
|year|movie_id|        categories|rank|
+----+--------+------------------+----+
|2000|    8606|(no genres listed)|   1|
|2000|    1587|            Action|   1|
|2000|    1518|            Action|   1|
|2000|    2582|            Action|   1|
|2000|    5460|            Action|   1|
|2000|   27611|            Action|   1|
|2000|   48304|            Action|   1|
|2000|   54995|            Action|   1|
|2000|    4629|            Action|   1|
|2000|   26606|            Action|   1|
|2000|   56775|            Action|   1|
|2000|   62008|            Action|   1|

I am pretty new to pyspark and am stuck here. Can anyone guide me on what I am doing wrong?

Upvotes: 1

Views: 358

Answers (1)

Oli

Reputation: 10406

You're right that you need to use a window, but first you need to perform an aggregation to compute the frequencies: ordering the window by the categories column only compares category names, it says nothing about how often each movie was rated.

First, let's compute the decade.

from pyspark.sql.functions import col, concat, substring, lit, count, desc, rank

df_decade = df.withColumn("decade", concat(substring(col("year"), 0, 3), lit("0")))
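For a quick sanity check of that expression, here is a minimal sketch on a throwaway two-row dataframe (it assumes an active SparkSession named spark, which is not part of the original answer): substring keeps the first three digits of the year and lit("0") appends a zero, so both 1994 and 1997 end up in decade "1990".

# Minimal sketch, not from the answer: verify the decade derivation on toy data.
demo = spark.createDataFrame([(122, 1994), (185, 1997)], ["movie_id", "year"])
demo.withColumn("decade", concat(substring(col("year"), 0, 3), lit("0"))).show()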

Then we compute the frequency by decade, category and movie_id:

agg_df = df_decade\
      .groupBy("decade", "categories", "movie_id")\
      .agg(count("*").alias("freq"))
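If you want to eyeball this intermediate result, a quick way (this line is mine, not part of the original answer) is to sort by the new freq column:

# Peek at the most frequently rated (decade, categories, movie_id) combinations.
agg_df.orderBy(desc("freq")).show(5)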

And finally, we define a window partitioned by decade and category and select the top 10 using the rank function:

from pyspark.sql import Window

w = Window.partitionBy("decade", "categories").orderBy(desc("freq"))
top10 = agg_df.withColumn("r", rank().over(w)).where(col("r") <= 10)
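One caveat worth adding (this note is mine, not part of the original answer): rank() assigns equal ranks to ties, so a (decade, categories) group can return more than 10 rows when several movies share the same freq. If you want exactly 10 rows per group, row_number() is the usual substitute:

from pyspark.sql.functions import row_number

# row_number() breaks ties arbitrarily, so each (decade, categories) group
# yields at most 10 rows.
top10_strict = agg_df.withColumn("r", row_number().over(w)).where(col("r") <= 10)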

Upvotes: 3
