wholeroll

Reputation: 193

How to Merge values from multiple rows so they can be processed together - Spark scala

I have multiple database rows per PersonId, with columns that may or may not have values. I'm using colors here because the data is text, not numeric, so it doesn't lend itself to built-in aggregation functions. A simplified example:

PersonId    ColA    ColB    ColC
100         red
100                 green
100                         gold
100         green
110                 yellow
110         white
110   
120         
etc...

I want to be able to decide in a function which column data to use per unique PersonId. A three-way join of the table against itself would be a good solution if the data didn't have multiple values (colors) per column. For example, that join merges three of the rows into one, but still produces multiple rows:

PersonId    ColA    ColB    ColC
100         red     green   gold
100         green                                   
110         white   yellow
110   
120

So the solution I'm looking for is something that lets me address all the values (colors) for a person in one place (a function), so the decision can be made across all their data. The real data of course has more columns, but the primary ones for this decision are the three columns shown. The data is being read in Scala Spark as a DataFrame, and I'd prefer using the API to SQL. I don't know if any of the exotic window or groupBy functions will help, or if it's going to come down to plain old iterate-and-accumulate. The technique used in [How to aggregate values into collection after groupBy?] might be applicable, but it's a bit of a leap.
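One sketch of the "all values in one place" idea, assuming a SparkSession `spark` is in scope and `df` has the columns above: collect each person's rows into an array of structs, then hand that whole array to a single UDF. The `decide` function below is hypothetical; it just takes the first non-empty color, but it sees every value for the person, so any rule could go there.

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._
import spark.implicits._

// Gather every row for a person into one array of (ColA, ColB, ColC) structs,
// so a single function can see all of that person's colors at once.
val perPerson = df
  .groupBy($"PersonId")
  .agg(collect_list(struct($"ColA", $"ColB", $"ColC")).as("rows"))

// Hypothetical decision function: receives all values for one person and
// returns whichever color the business rule prefers.
val decide = udf((rows: Seq[Row]) => {
  val colors = rows
    .flatMap(r => Seq(r.getAs[String]("ColA"),
                      r.getAs[String]("ColB"),
                      r.getAs[String]("ColC")))
    .filter(c => c != null && c.nonEmpty)
  colors.headOption.orNull // placeholder rule: first non-empty value
})

perPerson.select($"PersonId", decide($"rows").as("chosen")).show()
```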

Upvotes: 2

Views: 2337

Answers (1)

Sarath Chandra Vema

Reputation: 812

Consider using a custom UDF for this:

import org.apache.spark.sql.functions._
import spark.implicits._ // assumes a SparkSession `spark` is in scope (e.g. spark-shell)

val df = Seq(
  (100, "red", null, null),
  (100, null, "white", null),
  (100, null, null, "green"),
  (200, null, "red", null)
).toDF("PID", "A", "B", "C")

df.show()
+---+----+-----+-----+
|PID|   A|    B|    C|
+---+----+-----+-----+
|100| red| null| null|
|100|null|white| null|
|100|null| null|green|
|200|null|  red| null|
+---+----+-----+-----+

// Pick the first non-empty value from the collected set, or null if none.
// collect_set already drops nulls; the filter also guards against empty strings.
val customUDF = udf((values: Seq[String]) => {
  val nonEmpty = values.filter(_.nonEmpty)
  if (nonEmpty.isEmpty) null else nonEmpty.head
})

df.groupBy($"PID")
  .agg(
    customUDF(collect_set($"A")).as("colA"),
    customUDF(collect_set($"B")).as("colB"),
    customUDF(collect_set($"C")).as("colC"))
  .show()

+---+----+-----+-----+
|PID|colA| colB| colC|
+---+----+-----+-----+
|100| red|white|green|
|200|null|  red| null|
+---+----+-----+-----+
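If the rule is simply "take any non-null value per column", a similar result can likely be had without a UDF, using the built-in `first` aggregate with `ignoreNulls = true` (a sketch against the same `df`):

```scala
import org.apache.spark.sql.functions.first

df.groupBy($"PID")
  .agg(
    first($"A", ignoreNulls = true).as("colA"),
    first($"B", ignoreNulls = true).as("colB"),
    first($"C", ignoreNulls = true).as("colC"))
  .show()
```

Note that when a person has several non-null values in one column (like `red` and `green` in ColA above), which one `first` returns after a `groupBy` is not deterministic; if the choice matters, keep the decision logic in a UDF as in the answer.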


Upvotes: 1
