Reputation: 891
I have a DataFrame with five columns. For each name, I need the maximum of the summed minutes within each of the category columns (nw, cat, pt). The approach below works, but I have to build a window and a groupBy separately for every category column. Is there a better way to get the result in a single step instead of grouping multiple times?
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, sum}

val customers = sc.parallelize(Seq(
  ("Alice", "abc", "cat1", "pt1", 50.00),
  ("Alice", "abc", "cat1", "pt1", 45.00),
  ("Alice", "bcd", "cat2", "pt1", 55.00),
  ("Bob", "abc", "cat1", "pt4", 25.00),
  ("Bob", "bcd", "cat1", "pt4", 29.00),
  ("Bob", "av", "cat4", "pt4", 27.00)
)).toDF("name", "nw", "cat", "pt", "min")
// One window per category column; orderBy("min") makes each sum a running total
val wSpec1 = Window.partitionBy("name", "nw").orderBy("min")
val wSpec2 = Window.partitionBy("name", "cat").orderBy("min")
val wSpec3 = Window.partitionBy("name", "pt").orderBy("min")
// Running sum of min within each window
val test1 = customers.withColumn("sumnt", sum(customers("min")).over(wSpec1))
val test2 = customers.withColumn("sumct", sum(customers("min")).over(wSpec2))
val test3 = customers.withColumn("sumpt", sum(customers("min")).over(wSpec3))
// The maximum of each running sum is the total for that group
val data1 = test1.groupBy("name", "nw").agg(max($"sumnt"))
val data2 = test2.groupBy("name", "cat").agg(max($"sumct"))
val data3 = test3.groupBy("name", "pt").agg(max($"sumpt"))
// Combine the three per-group results on name
val res = data1.join(data2, Seq("name"), "left").join(data3, Seq("name"), "left")
+-----+---+----------+
| name| nw|max(sumnt)|
+-----+---+----------+
|Alice|bcd|      55.0|
|  Bob| av|      27.0|
|  Bob|bcd|      29.0|
|Alice|abc|      95.0|
|  Bob|abc|      25.0|
+-----+---+----------+
After running each aggregation individually and joining the results, I get the following:
+-----+---+----------+----+----------+---+----------+
| name| nw|max(sumnt)| cat|max(sumct)| pt|max(sumpt)|
+-----+---+----------+----+----------+---+----------+
|Alice|bcd|      55.0|cat1|      95.0|pt1|     150.0|
|Alice|bcd|      55.0|cat2|      55.0|pt1|     150.0|
|Alice|abc|      95.0|cat1|      95.0|pt1|     150.0|
|Alice|abc|      95.0|cat2|      55.0|pt1|     150.0|
|  Bob| av|      27.0|cat1|      54.0|pt4|      81.0|
|  Bob| av|      27.0|cat4|      27.0|pt4|      81.0|
|  Bob|bcd|      29.0|cat1|      54.0|pt4|      81.0|
|  Bob|bcd|      29.0|cat4|      27.0|pt4|      81.0|
|  Bob|abc|      25.0|cat1|      54.0|pt4|      81.0|
|  Bob|abc|      25.0|cat4|      27.0|pt4|      81.0|
+-----+---+----------+----+----------+---+----------+
Thanks in advance
Upvotes: 0
Views: 83
Reputation: 307
This should do the job:
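import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.sum

// Category columns, each partitioned together with "name"
val lst = List("nw", "cat", "pt")
// customers is a val in the question, so fold into a new DataFrame instead of reassigning it
val result = lst.foldLeft(customers) { (df, c) =>
  // Without orderBy the frame is the whole partition, so sum is already the group total
  // and a separate max over the same window would return the same value
  val window = Window.partitionBy("name", c)
  df.withColumn("max_" + c, sum($"min").over(window))
}
result.show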
It will generate the following output.
+-----+---+----+---+----+------+-------+------+
| name| nw| cat| pt| min|max_nw|max_cat|max_pt|
+-----+---+----+---+----+------+-------+------+
|  Bob| av|cat4|pt4|27.0|  27.0|   27.0|  81.0|
|  Bob|bcd|cat1|pt4|29.0|  29.0|   54.0|  81.0|
|  Bob|abc|cat1|pt4|25.0|  25.0|   54.0|  81.0|
|Alice|bcd|cat2|pt1|55.0|  55.0|   55.0| 150.0|
|Alice|abc|cat1|pt1|45.0|  95.0|   95.0| 150.0|
|Alice|abc|cat1|pt1|50.0|  95.0|   95.0| 150.0|
+-----+---+----+---+----+------+-------+------+
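To sanity-check the totals in the output above without a Spark session, the same per-group sums can be modelled on plain Scala collections. This is only an illustrative sketch, not Spark code: `Record` and `groupTotals` are hypothetical names introduced here, and the data is copied from the question.

```scala
// Hypothetical stand-in for one DataFrame row; fields mirror the question's columns.
final case class Record(name: String, nw: String, cat: String, pt: String, min: Double)

val records = Seq(
  Record("Alice", "abc", "cat1", "pt1", 50.0),
  Record("Alice", "abc", "cat1", "pt1", 45.0),
  Record("Alice", "bcd", "cat2", "pt1", 55.0),
  Record("Bob", "abc", "cat1", "pt4", 25.0),
  Record("Bob", "bcd", "cat1", "pt4", 29.0),
  Record("Bob", "av", "cat4", "pt4", 27.0)
)

// Total of `min` per (name, key) pair: the plain-collections analogue of
// sum(min) over a window partitioned by "name" and one category column.
def groupTotals(rs: Seq[Record], key: Record => String): Map[(String, String), Double] =
  rs.groupBy(r => (r.name, key(r))).map { case (k, group) => k -> group.map(_.min).sum }

val byNw  = groupTotals(records, _.nw)
val byCat = groupTotals(records, _.cat)
val byPt  = groupTotals(records, _.pt)
```

Looking up `byNw(("Alice", "abc"))`, `byCat(("Bob", "cat1"))` or `byPt(("Alice", "pt1"))` gives the values that the max_nw, max_cat and max_pt columns hold for the matching rows.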
Upvotes: 2