Babu

Getting the maximum cumulative minutes for each category column of a DataFrame using Scala

I have a DataFrame with 5 columns. For each name, I need the maximum cumulative minutes (column min) for each value of the category columns nw, cat, and pt. The approach below works, but it needs a separate window, groupBy, and join for every category column. Is there a better way to get the result in a single step instead of grouping multiple times?

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{max, sum}
import spark.implicits._  // toDF and the $"col" syntax (spark-shell provides spark and sc)

val customers = sc.parallelize(Seq(("Alice", "abc", "cat1", "pt1", 50.00),
                                   ("Alice", "abc", "cat1", "pt1", 45.00),
                                   ("Alice", "bcd", "cat2", "pt1", 55.00),
                                   ("Bob", "abc", "cat1", "pt4", 25.00),
                                   ("Bob", "bcd", "cat1", "pt4", 29.00),
                                   ("Bob", "av", "cat4", "pt4", 27.00))).toDF("name", "nw", "cat", "pt", "min")

// One window per category column; orderBy makes sum() a running (cumulative) total
val wSpec1 = Window.partitionBy("name", "nw").orderBy("min")
val wSpec2 = Window.partitionBy("name", "cat").orderBy("min")
val wSpec3 = Window.partitionBy("name", "pt").orderBy("min")

val test1 = customers.withColumn("sumnt", sum(customers("min")).over(wSpec1))
val test2 = customers.withColumn("sumct", sum(customers("min")).over(wSpec2))
val test3 = customers.withColumn("sumpt", sum(customers("min")).over(wSpec3))

// With non-negative minutes, the largest running total in each group is the group total
val data1 = test1.groupBy("name", "nw").agg(max($"sumnt"))
val data2 = test2.groupBy("name", "cat").agg(max($"sumct"))
val data3 = test3.groupBy("name", "pt").agg(max($"sumpt"))

// Combine the three per-column results on name
val res = data1.join(data2, Seq("name"), "left").join(data3, Seq("name"), "left")

For example, data1.show prints:

+-----+---+----------+
| name| nw|max(sumnt)|
+-----+---+----------+
|Alice|bcd|      55.0|
|  Bob| av|      27.0|
|  Bob|bcd|      29.0|
|Alice|abc|      95.0|
|  Bob|abc|      25.0|
+-----+---+----------+
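Why this works: a window with orderBy uses a frame that runs from the start of the partition to the current row, so sum becomes a cumulative sum, and its maximum is the group total as long as the minutes are non-negative. A window without orderBy covers the whole partition, so sum is the total directly. A plain-Scala model of a single partition (no Spark involved; RunningSumCheck and its methods are hypothetical names) illustrates both, including a negative-value case where they diverge:

```scala
// Plain-Scala model (no Spark) of what the question's windows compute for one
// (name, category) partition. Assumption: the partition's min values arrive as a non-empty List[Double].
object RunningSumCheck {
  // With orderBy("min"), sum(...) over the window is a cumulative sum in ascending
  // min order; this returns the largest of those cumulative sums.
  def maxOfRunningSum(mins: List[Double]): Double =
    mins.sortWith(_ < _).scanLeft(0.0)(_ + _).tail.reduce((a, b) => math.max(a, b))

  // Without orderBy the window frame is the whole partition, so sum(...) is the total.
  def partitionTotal(mins: List[Double]): Double = mins.sum

  def main(args: Array[String]): Unit = {
    val aliceAbc = List(50.0, 45.0)             // Alice/abc rows from the question
    println(maxOfRunningSum(aliceAbc))          // 95.0 (cumulative sums: 45.0, 95.0)
    println(partitionTotal(aliceAbc))           // 95.0
    // With negative values the two differ: cumulative sums of (-2.0, -1.0) are -2.0, -3.0
    println(maxOfRunningSum(List(-1.0, -2.0)))  // -2.0
    println(partitionTotal(List(-1.0, -2.0)))   // -3.0
  }
}
```

For the Alice/abc partition both give 95.0, matching the corresponding data1 row above.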

After computing each aggregation separately and joining the three results, I get the following:

+-----+---+----------+----+----------+---+----------+
| name| nw|max(sumnt)| cat|max(sumct)| pt|max(sumpt)|
+-----+---+----------+----+----------+---+----------+
|Alice|bcd|      55.0|cat1|      95.0|pt1|     150.0|
|Alice|bcd|      55.0|cat2|      55.0|pt1|     150.0|
|Alice|abc|      95.0|cat1|      95.0|pt1|     150.0|
|Alice|abc|      95.0|cat2|      55.0|pt1|     150.0|
|  Bob| av|      27.0|cat1|      54.0|pt4|      81.0|
|  Bob| av|      27.0|cat4|      27.0|pt4|      81.0|
|  Bob|bcd|      29.0|cat1|      54.0|pt4|      81.0|
|  Bob|bcd|      29.0|cat4|      27.0|pt4|      81.0|
|  Bob|abc|      25.0|cat1|      54.0|pt4|      81.0|
|  Bob|abc|      25.0|cat4|      27.0|pt4|      81.0|
+-----+---+----------+----+----------+---+----------+
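The repeated rows come from joining on name alone: for each name, every nw group is paired with every cat group and every pt group. A quick plain-Scala count over the question's sample rows (no Spark; JoinFanOut is a made-up name) predicts the number of joined rows per name:

```scala
// Plain-Scala model (no Spark) of why the joined result has 10 rows.
object JoinFanOut {
  // The question's sample rows as (name, nw, cat, pt, min)
  val rows = List(
    ("Alice", "abc", "cat1", "pt1", 50.00),
    ("Alice", "abc", "cat1", "pt1", 45.00),
    ("Alice", "bcd", "cat2", "pt1", 55.00),
    ("Bob", "abc", "cat1", "pt4", 25.00),
    ("Bob", "bcd", "cat1", "pt4", 29.00),
    ("Bob", "av", "cat4", "pt4", 27.00))

  // Joining data1, data2 and data3 on "name" only pairs every nw group with every
  // cat group and every pt group of the same name, so the joined row count per name is
  // (#distinct nw) * (#distinct cat) * (#distinct pt).
  def joinedRowsPerName: Map[String, Int] =
    rows.groupBy(_._1).map { case (name, rs) =>
      name -> (rs.map(_._2).distinct.size * rs.map(_._3).distinct.size * rs.map(_._4).distinct.size)
    }

  def main(args: Array[String]): Unit = {
    println(joinedRowsPerName("Alice"))  // 4
    println(joinedRowsPerName("Bob"))    // 6
  }
}
```

Alice has 2 nw values, 2 cat values and 1 pt value (4 rows); Bob has 3, 2 and 1 (6 rows), which matches the 10 rows in the joined table.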

Thanks in advance


Answers (1)

Mann

This should do the job with one loop over the category columns and no joins:

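  import org.apache.spark.sql.expressions.Window
  import org.apache.spark.sql.functions.{col, max, sum}
  import spark.implicits._  // for the $"min" syntax

  // Fold over the category columns, adding one max_<column> column per pass;
  // foldLeft avoids reassigning customers, which the question declares as a val
  val categoryCols = List("nw", "cat", "pt")
  val result = categoryCols.foldLeft(customers) { (df, c) =>
    // No orderBy, so the window frame is the whole (name, c) partition
    val window = Window.partitionBy("name", c)
    df.withColumn("sum_" + c, sum($"min").over(window))
      .withColumn("max_" + c, max(col("sum_" + c)).over(window))
      .drop("sum_" + c)
  }
  result.show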

It will generate the following output.

+-----+---+----+---+----+------+-------+------+
| name| nw| cat| pt| min|max_nw|max_cat|max_pt|
+-----+---+----+---+----+------+-------+------+
|  Bob| av|cat4|pt4|27.0|  27.0|   27.0|  81.0|
|  Bob|bcd|cat1|pt4|29.0|  29.0|   54.0|  81.0|
|  Bob|abc|cat1|pt4|25.0|  25.0|   54.0|  81.0|
|Alice|bcd|cat2|pt1|55.0|  55.0|   55.0| 150.0|
|Alice|abc|cat1|pt1|45.0|  95.0|   95.0| 150.0|
|Alice|abc|cat1|pt1|50.0|  95.0|   95.0| 150.0|
+-----+---+----+---+----+------+-------+------+
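The answer's output can be cross-checked without Spark. This sketch models it with plain Scala collections (Customer, categoryCols and rowTotals are illustrative names, not Spark APIs): each row gets, for every category column, the total of min over the rows sharing its name and that column's value, which is what a window partitioned by ("name", c) with no orderBy attaches to each row.

```scala
// Plain-Scala model (no Spark) of the answer's result.
object PerRowTotals {
  case class Customer(name: String, nw: String, cat: String, pt: String, min: Double)

  // The question's sample rows
  val customers: List[Customer] = List(
    Customer("Alice", "abc", "cat1", "pt1", 50.00),
    Customer("Alice", "abc", "cat1", "pt1", 45.00),
    Customer("Alice", "bcd", "cat2", "pt1", 55.00),
    Customer("Bob", "abc", "cat1", "pt4", 25.00),
    Customer("Bob", "bcd", "cat1", "pt4", 29.00),
    Customer("Bob", "av", "cat4", "pt4", 27.00))

  // Category columns paired with an accessor (stand-ins for "nw", "cat", "pt")
  val categoryCols: List[(String, Customer => String)] = List(
    ("nw", (c: Customer) => c.nw),
    ("cat", (c: Customer) => c.cat),
    ("pt", (c: Customer) => c.pt))

  // Per column: total min for every (name, value) group, keyed by column name
  val totals: Map[String, Map[(String, String), Double]] =
    categoryCols.map { case (colName, key) =>
      colName -> customers.groupBy(c => (c.name, key(c))).map { case (k, rs) => k -> rs.map(_.min).sum }
    }.toMap

  // Totals for one row, named like the answer's max_nw / max_cat / max_pt columns
  def rowTotals(row: Customer): Map[String, Double] =
    categoryCols.map { case (colName, key) =>
      ("max_" + colName) -> totals(colName)((row.name, key(row)))
    }.toMap

  def main(args: Array[String]): Unit =
    customers.foreach(r => println(s"$r ${rowTotals(r)}"))
}
```

For the Bob/bcd row this gives 29.0, 54.0 and 81.0, matching max_nw, max_cat and max_pt in the table above.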
