Reputation: 311
I apologize for the verbose title, but I really couldn't come up with something better.
Basically, I have data with the following schema:
|-- id: string (nullable = true)
|-- mainkey: map (nullable = true)
| |-- key: string
| |-- value: array (valueContainsNull = true)
| | |-- element: struct (containsNull = true)
| | | |-- price: double (nullable = true)
| | | |-- recordtype: string (nullable = true)
Let me use the following example data:
{"id":1, "mainkey":{"key1":[{"price":0.01,"recordtype":"BID"}],"key2":[{"price":4.3,"recordtype":"FIXED"}],"key3":[{"price":2.0,"recordtype":"BID"}]}}
{"id":2, "mainkey":{"key4":[{"price":2.50,"recordtype":"BID"}],"key5":[{"price":2.4,"recordtype":"BID"}],"key6":[{"price":0.19,"recordtype":"BID"}]}}
For each of the two records above, I want to calculate the mean of all prices whose recordtype is "BID". So, for the first record (with "id":1), there are 2 such bids, with prices 0.01 and 2.0, so the mean rounded to 2 decimal places is 1.01. For the second record (with "id":2), there are 3 bids, with prices 2.5, 2.4 and 0.19, and the mean is 1.70. So I want the following output:
+---+---------+
| id|meanvalue|
+---+---------+
| 1| 1.01|
| 2| 1.7|
+---+---------+
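As a quick plain-Scala sanity check of that arithmetic (independent of Spark, using only the BID prices listed above):
// BID prices per record, taken from the sample data.
val bidPrices1 = Seq(0.01, 2.0)
val bidPrices2 = Seq(2.50, 2.4, 0.19)
bidPrices1.sum / bidPrices1.size  // ≈ 1.005 -> 1.01 rounded to 2 decimal places
bidPrices2.sum / bidPrices2.size  // ≈ 1.697 -> 1.70 rounded to 2 decimal places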
The following code does it:
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.{explode, mean, round}
import org.apache.spark.sql.types._

// Target type for the exploded array elements.
case class RecordValue(price: Double, recordtype: String)

val exSchema = new StructType().add("id", StringType).add("mainkey",
  MapType(StringType, new ArrayType(new StructType().add("price", DoubleType).add("recordtype", StringType), true)))
val exJsonDf = spark.read.schema(exSchema).json("file:///data/json_example")

// Explode the map into (key, value) rows, then flatten each value array into one RecordValue row per struct.
val explodeExJson = exJsonDf.select($"id", explode($"mainkey")).explode($"value") {
  case Row(records: Seq[Row] @unchecked) => records.map { r =>
    RecordValue(r(0).asInstanceOf[Double], r(1).asInstanceOf[String])
  }
}.cache()

val filteredExJson = explodeExJson.filter($"recordtype" === "BID")
val aggExJson = filteredExJson.groupBy("id").agg(round(mean("price"), 2).alias("meanvalue"))
The problem is that it relies on an "expensive" explode operation, which becomes an issue when I am dealing with lots of data, especially when the map can contain a lot of keys.
Please let me know if you can think of a simpler solution, using UDFs or otherwise. Please also keep in mind that I am a beginner to Spark, and hence may have missed some stuff that would be obvious to you.
Any help would be really appreciated. Thanks in advance!
Upvotes: 2
Views: 1985
Reputation: 35249
If aggregation is limited to a single Row, a udf will solve this:
import org.apache.spark.util.StatCounter
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.Row
// For each row, gather all BID prices across the map's arrays and average them.
val meanPrice = udf((map: Map[String, Seq[Row]]) => {
  val prices = map.values
    .flatMap(x => x)                                 // flatten the per-key arrays of structs
    .filter(_.getAs[String]("recordtype") == "BID")  // keep only BID records
    .map(_.getAs[Double]("price"))
  StatCounter(prices).mean                           // mean of the collected prices
})
df.select($"id", meanPrice($"mainkey"))
Upvotes: 2