Reputation: 883
I found the following Scala code for selecting the top n rows per group (grouped by unique_id) from a DataFrame:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.row_number
val window = Window.partitionBy("userId").orderBy($"rating".desc)
dataframe.withColumn("r", row_number.over(window)).where($"r" <= n)
I tried to translate it to PySpark as follows:
from pyspark.sql.functions import row_number, desc
from pyspark.sql.window import Window
w = Window.partitionBy(post_tags.EntityID).orderBy(post_tags.Weight)
newdata=post_tags.withColumn("r", row_number.over(w)).where("r" <= 3)
I get the following error:
AttributeError: 'function' object has no attribute 'over'
Please help me fix this.
Upvotes: 1
Views: 5045
Reputation: 883
I found the answer to this:
from pyspark.sql.window import Window
from pyspark.sql.functions import rank, col

# rank rows within each user_id partition, highest score first
window = Window.partitionBy(df['user_id']).orderBy(df['score'].desc())

(df.select('*', rank().over(window).alias('rank'))
   .filter(col('rank') <= 2)
   .show())
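For completeness: the AttributeError in the question came from calling .over on the row_number function object itself. row_number() must be called first, because it is the returned Column that has an .over method. Here is a minimal corrected sketch of the original attempt, reusing the question's post_tags DataFrame and its EntityID and Weight columns (sorted descending here to mirror the Scala example's $"rating".desc):

from pyspark.sql.functions import row_number, col
from pyspark.sql.window import Window

# row_number() returns a Column; the bare function object has no .over,
# which is what triggered the AttributeError above
w = Window.partitionBy(post_tags.EntityID).orderBy(post_tags.Weight.desc())
newdata = post_tags.withColumn("r", row_number().over(w)).where(col("r") <= 3)

One difference worth knowing: rank() can return more than n rows per group when there are ties (and leaves gaps in the numbering), whereas row_number() assigns consecutive numbers and returns exactly n rows per group.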
Credits to @mtoto for his answer https://stackoverflow.com/a/38398563/5165377
Upvotes: 1