Reputation: 317
I have a dataframe as below. I'm trying to figure out how to calculate the percentage of each colour per product and generate something like the expected output. I tried to use a window, w = Window.partitionBy("prod_name", "colour"),
and do a count, df.withColumn("cnt", F.count("colour").over(w)),
but that's far from correct, as it only counts the occurrences of each colour. Could someone please help? Many thanks.
Input:
prod_name | colour
---------------
A | blue
A | blue
A | yellow
B | green
B | blue
C | red
Output:
prod_name | colour | percentage
-------------------------------
A | blue | 0.67 --- blue accounts for 2/3 of product A
A | yellow | 0.33 --- yellow accounts for 1/3 of product A
B | green | 0.5 --- green accounts for 1/2 of product B
B | blue | 0.5 --- blue accounts for 1/2 of product B
C | red | 1 --- red accounts for 100% of product C
Upvotes: 2
Views: 643
Reputation: 150
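In PySpark, you can do it like this if you want to use window functions:
from pyspark.sql import Window
from pyspark.sql.functions import count
# rows per product, then rows per (product, colour) pair
df = df.withColumn("cntProduct", count("prod_name").over(Window.partitionBy("prod_name")))
df = df.withColumn("cntProduct_colour", count("colour").over(Window.partitionBy(["prod_name", "colour"])))
# fraction of each colour within its product
df = df.withColumn("required", df.cntProduct_colour / df.cntProduct)
# every input row carries the same ratio, so keep one row per pair
df = df.select("prod_name", "colour", "required").distinct()
df.show()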
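To sanity-check the arithmetic without a Spark session, here is a minimal plain-Python sketch of the same two counts and their ratio. It is not Spark code; the names `rows` and `colour_share` are made up for illustration, and the data is copied from the question.

```python
from collections import Counter

# Sample rows copied from the question: (prod_name, colour)
rows = [
    ("A", "blue"), ("A", "blue"), ("A", "yellow"),
    ("B", "green"), ("B", "blue"),
    ("C", "red"),
]

def colour_share(rows):
    """Return {(prod_name, colour): fraction of that product's rows}."""
    # Rows per product, the counterpart of cntProduct
    per_product = Counter(prod for prod, _ in rows)
    # Rows per (product, colour) pair, the counterpart of cntProduct_colour
    per_pair = Counter(rows)
    # Dict keys are unique, so no separate distinct() step is needed here
    return {pair: n / per_product[pair[0]] for pair, n in per_pair.items()}

shares = colour_share(rows)
for (prod, colour), share in sorted(shares.items()):
    print(prod, colour, round(share, 2))
```

The dictionary has one entry per (prod_name, colour) pair, which is the role `distinct()` plays in the window version.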
Upvotes: 4
Reputation: 1751
The following Scala solution might help.
First, generate the total row count for each prod_name.
import org.apache.spark.sql.functions.{broadcast, col, count}
val productCountDF = inputDF
.groupBy("prod_name").agg(count("*") as "total_products")
Join the above DF with the original input DF.
// Assumes productCountDF is small enough to broadcast, so the join does not have to shuffle inputDF.
val newDF = inputDF
.join(broadcast(productCountDF), Seq("prod_name"))
After the join, every row carries the total count for its prod_name. Next, count the rows per colour within each product and divide by that total.
val finalDF = newDF
.groupBy("prod_name", "total_products", "colour")
.agg(count("*") as "total_products_per_colour")
finalDF
.withColumn("percentage", col("total_products_per_colour")/ col("total_products"))
.drop("total_products_per_colour", "total_products")
.show(false)
Upvotes: 2