Ganesh Sundar

Reputation: 311

Aggregation on an array of structs in a map inside a Spark dataframe

I apologize for the verbose title, but I really couldn't come up with something better.

Basically, I have data with the following schema:

 |-- id: string (nullable = true)
 |-- mainkey: map (nullable = true)
 |    |-- key: string
 |    |-- value: array (valueContainsNull = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- price: double (nullable = true)
 |    |    |    |-- recordtype: string (nullable = true)

Let me use the following example data:

{"id":1, "mainkey":{"key1":[{"price":0.01,"recordtype":"BID"}],"key2":[{"price":4.3,"recordtype":"FIXED"}],"key3":[{"price":2.0,"recordtype":"BID"}]}}
{"id":2, "mainkey":{"key4":[{"price":2.50,"recordtype":"BID"}],"key5":[{"price":2.4,"recordtype":"BID"}],"key6":[{"price":0.19,"recordtype":"BID"}]}}

For each of the two records above, I want to calculate the mean of all prices where the recordtype is "BID". So, for the first record (with "id":1), there are 2 such bids, with prices 0.01 and 2.0, and the mean rounded to 2 decimal places is 1.01. For the second record (with "id":2), there are 3 bids, with prices 2.5, 2.4 and 0.19, and the mean is 1.70. So I want the following output:

+---+---------+
| id|meanvalue|
+---+---------+
|  1|     1.01|
|  2|      1.7|
+---+---------+

The following code does it:

import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import spark.implicits._

// target type for the exploded records
case class RecordValue(price: Double, recordtype: String)

val exSchema = new StructType()
  .add("id", StringType)
  .add("mainkey", MapType(StringType, new ArrayType(
    new StructType().add("price", DoubleType).add("recordtype", StringType), true)))

val exJsonDf = spark.read.schema(exSchema).json("file:///data/json_example")

// explode the map into (key, value) rows, then flatten each array of structs
val explodeExJson = exJsonDf.select($"id", explode($"mainkey")).explode($"value") {
  case Row(records: Seq[Row] @unchecked) => records.map { r =>
    RecordValue(r.getAs[Double]("price"), r.getAs[String]("recordtype"))
  }
}.cache()

val filteredExJson = explodeExJson.filter($"recordtype" === "BID")

val aggExJson = filteredExJson.groupBy("id").agg(round(mean("price"), 2).alias("meanvalue"))
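
Calling show() on the result should reproduce the table shown earlier (assuming the two example records above):

aggExJson.show()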

The problem is that it uses an "expensive" explode operation, which becomes a bottleneck when I am dealing with lots of data, especially when the map can contain many keys.

Please let me know if you can think of a simpler solution, using UDFs or otherwise. Please also keep in mind that I am a beginner to Spark, and hence may have missed some stuff that would be obvious to you.

Any help would be really appreciated. Thanks in advance!

Upvotes: 2

Views: 1985

Answers (1)

Alper t. Turker

Reputation: 35249

If the aggregation is limited to a single row, a UDF will solve this:

import org.apache.spark.util.StatCounter
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row

val meanPrice = udf((map: Map[String, Seq[Row]]) => {
  // collect the prices of all "BID" records across every key of the map
  val prices = map.values
    .flatMap(x => x)
    .filter(_.getAs[String]("recordtype") == "BID")
    .map(_.getAs[Double]("price"))
  // StatCounter computes the mean in a single pass
  StatCounter(prices).mean
})

df.select($"id", meanPrice($"mainkey"))

Upvotes: 2
