1pluszara

Reputation: 1528

Spark: Mapgroups on a Dataset

I'm trying the mapGroups function on the dataset below and I'm not sure why I'm getting 0 for the "TotalValue" column. Am I missing something here? Please advise.

Spark version: 2.0, Scala version: 2.11

import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.asc

case class Record(Hour: Int, Category: String, TotalComm: Double, TotalValue: Int)
val ss = SparkSession.builder().getOrCreate()
import ss.implicits._

val df: DataFrame = ss.sparkContext.parallelize(Seq(
(0, "cat26", 30.9, 200), (0, "cat26", 22.1, 100), (0, "cat95", 19.6, 300), (1, "cat4", 1.3, 100),
(1, "cat23", 28.5, 100), (1, "cat4", 26.8, 400), (1, "cat13", 12.6, 250), (1, "cat23", 5.3, 300),
(0, "cat26", 39.6, 30), (2, "cat40", 29.7, 500), (1, "cat4", 27.9, 600), (2, "cat68", 9.8, 100),
(1, "cat23", 35.6, 500))).toDF("Hour", "Category","TotalComm", "TotalValue")

val resultSum = df.as[Record].map(row => ((row.Hour,row.Category),(row.TotalComm,row.TotalValue)))
.groupByKey(_._1).mapGroups{case(k,iter) => (k._1,k._2,iter.map(x => x._2._1).sum,iter.map(y => y._2._2).sum)}
.toDF("KeyHour","KeyCategory","TotalComm","TotalValue").orderBy(asc("KeyHour"))

resultSum.show()

+-------+-----------+---------+----------+
|KeyHour|KeyCategory|TotalComm|TotalValue|
+-------+-----------+---------+----------+
|      0|      cat26|     92.6|         0|
|      0|      cat95|     19.6|         0|
|      1|      cat13|     12.6|         0|
|      1|      cat23|     69.4|         0|
|      1|       cat4|     56.0|         0|
|      2|      cat40|     29.7|         0|
|      2|      cat68|      9.8|         0|
+-------+-----------+---------+----------+  

Upvotes: 2

Views: 13967

Answers (2)

Shaido

Reputation: 28332

As Ramesh Maharjan has pointed out, the issue lies in consuming the iterator twice, which results in the TotalValue column being 0. However, there is no need to use groupByKey and mapGroups in the first place. The same can be accomplished with groupBy and agg, which results in much cleaner and easier-to-read code. As a bonus, it also avoids the slower groupByKey.

The following will work just as well:

import org.apache.spark.sql.functions.{asc, sum}

val resultSum = df.groupBy($"Hour", $"Category")
  .agg(sum($"TotalComm").as("TotalComm"), sum($"TotalValue").as("TotalValue"))
  .orderBy(asc("Hour"))

Result:

+----+--------+---------+----------+
|Hour|Category|TotalComm|TotalValue|
+----+--------+---------+----------+
|   0|   cat95|     19.6|       300|
|   0|   cat26|     92.6|       330|
|   1|   cat23|     69.4|       900|
|   1|   cat13|     12.6|       250|
|   1|    cat4|     56.0|      1100|
|   2|   cat68|      9.8|       100|
|   2|   cat40|     29.7|       500|
+----+--------+---------+----------+

If you still want to rename the Hour and Category columns, that is easily done by changing the groupBy to:

groupBy($"Hour".as("KeyHour"), $"Category".as("KeyCategory"))

Upvotes: 2

Ramesh Maharjan

Reputation: 41957

iter inside mapGroups is an Iterator, so it can be traversed only once. When you compute iter.map(x => x._2._1).sum, that traversal exhausts the iterator, and the subsequent iter.map(y => y._2._2).sum has nothing left to iterate over, so it yields 0. You therefore need a way to calculate both sums in a single pass.
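
A minimal plain-Scala sketch (not Spark-specific, just an illustration of the one-pass behaviour) shows the same effect:

// an Iterator of (TotalComm, TotalValue)-like pairs
val it = Iterator((1.0, 10), (2.0, 20))
val commSum  = it.map(_._1).sum  // 3.0 -- this traversal exhausts the iterator
val valueSum = it.map(_._2).sum  // 0   -- the iterator is already empty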

For loop with ListBuffers

For simplicity, I have used a for loop and ListBuffers to accumulate both sums at once:

import scala.collection.mutable.ListBuffer

val resultSum = df.as[Record].map(row => ((row.Hour, row.Category), (row.TotalComm, row.TotalValue)))
  .groupByKey(_._1)
  .mapGroups { case (k, iter) =>
    val listBuffer1 = new ListBuffer[Double]
    val listBuffer2 = new ListBuffer[Int]
    for (a <- iter) {
      listBuffer1 += a._2._1
      listBuffer2 += a._2._2
    }
    (k._1, k._2, listBuffer1.sum, listBuffer2.sum)
  }
  .toDF("KeyHour", "KeyCategory", "TotalComm", "TotalValue").orderBy($"KeyHour".asc)

This should give you the correct result:

+-------+-----------+---------+----------+
|KeyHour|KeyCategory|TotalComm|TotalValue|
+-------+-----------+---------+----------+
|      0|      cat26|     92.6|       330|
|      0|      cat95|     19.6|       300|
|      1|      cat23|     69.4|       900|
|      1|      cat13|     12.6|       250|
|      1|       cat4|     56.0|      1100|
|      2|      cat68|      9.8|       100|
|      2|      cat40|     29.7|       500|
+-------+-----------+---------+----------+
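
As a side note, the same single-pass accumulation can be sketched with foldLeft on the iterator, which avoids the mutable buffers entirely (a variant of the code above, not from the original answer; resultSumFold is just an illustrative name):

val resultSumFold = df.as[Record].map(row => ((row.Hour, row.Category), (row.TotalComm, row.TotalValue)))
  .groupByKey(_._1)
  .mapGroups { case (k, iter) =>
    // accumulate both sums in a single traversal of the iterator
    val (commSum, valueSum) = iter.foldLeft((0.0, 0)) { case ((c, v), (_, (comm, value))) =>
      (c + comm, v + value)
    }
    (k._1, k._2, commSum, valueSum)
  }
  .toDF("KeyHour", "KeyCategory", "TotalComm", "TotalValue").orderBy($"KeyHour".asc)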

I hope the answer is helpful.

Upvotes: 9
