Rory Byrne

Reputation: 923

How to aggregate data using computed groups

My data is stored in a Spark DataFrame, roughly in this form:

Col1 Col2

A1   -5
B1   -20
C1   7
A2   3
B2   -4
C2   17

I want to turn this into:

Col3 Col4

A   -2
B   -24
C    24

(That is, summing the numbers for each letter and collapsing X1 and X2 into X.)

How can I do this using the DataFrame API?

edit:

The Col1 values are actually arbitrary strings (endpoints) that I want to combine into a single column (span), perhaps in the form "A1-A2". I plan to keep the endpoint-to-endpoint mapping in a Map and query it from my UDF. Can my UDF return None? Say I didn't want to include A in Col3 at all, but did want to include B and C: could I add another case to your example so that the A rows are skipped when mapping Col1 to Col3?

Upvotes: 0

Views: 382

Answers (1)

zero323

Reputation: 330323

You can extract the group column and use it as the key for the aggregation. Assuming your data follows the pattern in the example:

With raw SQL:

case class Record(Col1: String, Col2: Int)

val df = sqlContext.createDataFrame(Seq(
    Record("A1", -5),
    Record("B1", -20),
    Record("C1", 7),
    Record("A2", 3),
    Record("B2", -4),
    Record("C2", 17)))

df.registerTempTable("df")

sqlContext.sql(
    """SELECT col3, sum(col2) AS col4 FROM (
        SELECT col2, SUBSTR(Col1, 1, 1) AS col3 FROM df
    ) tmp GROUP BY col3""").show

+----+----+
|col3|col4|
+----+----+
|   A|  -2|
|   B| -24|
|   C|  24|
+----+----+

With Scala API:

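import org.apache.spark.sql.functions.{udf, sum}
import sqlContext.implicits._ // for the $"..." syntax; already imported in spark-shell

// Use the first character of Col1 as the group key
val getGroup = udf((s: String) => s.substring(0, 1))

df
  .select(getGroup($"col1").alias("col3"), $"col2")
  .groupBy($"col3")
  .agg(sum($"col2").alias("col4"))
  .show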

+----+----+
|col3|col4|
+----+----+
|   A|  -2|
|   B| -24|
|   C|  24|
+----+----+

If the group pattern is more complex, you can adjust SUBSTR or the getGroup function accordingly. For example:

val getGroup = {
  val pattern = "^[A-Z]+".r
  udf((s: String) => pattern.findFirstIn(s) match {
    case Some(g) => g
    case None => "Unknown"
  })
}
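To see what that pattern extracts, here is a minimal plain-Scala sketch of the same logic without the udf wrapper, so it runs without Spark. The helper name `groupOf` and the sample inputs are made up for illustration, and the null guard is an addition that the snippet above does not have.

```scala
// Same regex as above: one or more leading capital letters
val pattern = "^[A-Z]+".r

// Hypothetical helper: returns the leading letters, or "Unknown" when
// the input is null or does not start with a capital letter
def groupOf(s: String): String =
  Option(s).flatMap(x => pattern.findFirstIn(x)).getOrElse("Unknown")

val samples = Seq("A1", "AB12", "x9", "")
val groups = samples.map(groupOf)

println(groups.mkString(", "))
```

Testing the function on a local collection first makes it easier to check edge cases before wrapping it in `udf(...)`.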

Edit:

If you want to ignore some groups, simply add a WHERE clause. With raw SQL this is straightforward; with the Scala API it takes a bit more effort:

import org.apache.spark.sql.functions.{not, lit}

df
  .select(...) // As before
  .where(not($"col3".in(lit("A")))) // Column.in is deprecated since Spark 1.5; use isin there
  .groupBy(...).agg(...) // As before

If you want to discard multiple groups, you can pass them as varargs:

val toDiscard = List("A", "B").map(lit(_))

df
  .select(...) // As before
  .where(not($"col3".in(toDiscard: _*)))
  .groupBy(...).agg(...) // As before
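For comparison, the raw SQL variant could look roughly like the sketch below. It only builds the query string, so it runs without Spark. `discarded` is a hypothetical list of group keys, and the commented-out last line assumes the `sqlContext` and the temp table `df` registered earlier in this answer.

```scala
// Hypothetical list of group keys to skip
val discarded = Seq("A", "B")

// Render the keys as SQL string literals: 'A', 'B'
// (no escaping is done here, so this is only safe for trusted values)
val inList = discarded.map(g => s"'$g'").mkString(", ")

// Same query as the raw SQL example, with a WHERE clause on the group key
val query =
  s"""SELECT col3, sum(col2) AS col4 FROM (
     |    SELECT col2, SUBSTR(Col1, 1, 1) AS col3 FROM df
     |) tmp WHERE col3 NOT IN ($inList) GROUP BY col3""".stripMargin

println(query)

// Assumes sqlContext and the temp table "df" from the earlier snippet:
// sqlContext.sql(query).show
```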

Can my UDF return None?

It cannot, but it can return null:

val getGroup2 = udf((s: String) => s.substring(0, 1) match {
    case x if x != "A" => x
    case _ => null: String
})

df
  .select(getGroup2($"col1").alias("col3"), $"col2")
  .where($"col3".isNotNull)
  .groupBy(...).agg(...) // As before

+----+----+
|col3|col4|
+----+----+
|   B| -24|
|   C|  24|
+----+----+
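The same null trick should carry over to the Map lookup described in the question. Below is a minimal plain-Scala sketch, assuming a hypothetical `spans` map from endpoint to span label. The map contents, the names `spans` and `spanOf`, and the "B1-B2" label format are all made up for illustration. The local collection stands in for the DataFrame so that the logic can be checked without Spark.

```scala
// Hypothetical endpoint -> span mapping; A1 and A2 are deliberately missing
val spans: Map[String, String] = Map(
  "B1" -> "B1-B2", "B2" -> "B1-B2",
  "C1" -> "C1-C2", "C2" -> "C1-C2")

// Returns null for endpoints that are not in the map,
// mirroring the `case _ => null` branch of getGroup2
def spanOf(endpoint: String): String = spans.get(endpoint).orNull

// Local stand-in for the (Col1, Col2) rows of the example data
val rows = Seq(("A1", -5), ("B1", -20), ("C1", 7), ("A2", 3), ("B2", -4), ("C2", 17))

// Map each endpoint to its span, drop unmapped rows, then sum per span
val totals = rows
  .map { case (endpoint, value) => (spanOf(endpoint), value) }
  .filter { case (span, _) => span != null }
  .groupBy { case (span, _) => span }
  .map { case (span, group) => (span, group.map(_._2).sum) }

println(totals)
```

With Spark, `spanOf` could presumably be wrapped as `udf(spanOf _)` and used in place of `getGroup2`, keeping the `isNotNull` filter as shown above.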

Upvotes: 2
