Reputation: 362
I want to write a UDF over a DataFrame that compares a row's values against the values of the other rows in the same group, where the grouping is by multiple keys. Since UDFs operate on a single row, I want to write a query that collects the values of each group into new array columns.
For example over this Input:
id | categoryAB | categoryXY | value1 | value2 |
---|---|---|---|---|
1 | A | X | 0.2 | True |
2 | A | X | 0.3 | False |
3 | A | X | 0.2 | True |
4 | B | X | 0.4 | True |
5 | B | X | 0.1 | True |
6 | B | Y | 0.5 | False |
I want to add the group values as new columns. Expected result:
id | categoryAB | categoryXY | value1 | value2 | group1 | group2 |
---|---|---|---|---|---|---|
1 | A | X | 0.2 | True | [0.2, 0.3, 0.2] | [True, False, True] |
2 | A | X | 0.3 | False | [0.2, 0.3, 0.2] | [True, False, True] |
3 | A | X | 0.2 | True | [0.2, 0.3, 0.2] | [True, False, True] |
4 | B | X | 0.4 | True | [0.4, 0.1] | [True, True] |
5 | B | X | 0.1 | True | [0.4, 0.1] | [True, True] |
6 | B | Y | 0.5 | False | [0.5] | [False] |
To be clearer about the grouping: there are 3 groups in this example, keyed by (categoryAB, categoryXY): (A, X), (B, X), and (B, Y).
I need to implement it in Scala with Spark SQL structures and functions but a generic SQL answer could be guiding.
Upvotes: 1
Views: 153
Reputation: 1026
There might be a more optimized method, but here is how I usually do it:
import org.apache.spark.sql.functions.collect_list
import spark.implicits._

val df = Seq(
  (1, "A", "X", 0.2, true),
  (2, "A", "X", 0.3, false),
  (3, "A", "X", 0.2, true),
  (4, "B", "X", 0.4, true),
  (5, "B", "X", 0.1, true),
  (6, "B", "Y", 0.5, false)
).toDF("id", "categoryAB", "categoryXY", "value1", "value2")

// Collect each group's values into arrays, then join them back to the original rows
df.join(
  df.groupBy("categoryAB", "categoryXY")
    .agg(
      collect_list('value1) as "group1",
      collect_list('value2) as "group2"
    ),
  Seq("categoryAB", "categoryXY")
).show()
The idea is that I compute the aggregation separately, grouped by categoryAB and categoryXY, and then join the resulting dataframe back to the original one (make sure that df is cached if it is the result of heavy computations, as otherwise it will be computed twice).
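If you want to avoid the self-join entirely, note that collect_list can also be used as a window function partitioned by the two keys, so the arrays are attached to every row in a single pass. A sketch (untested here, using the same df as above):

```scala
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, collect_list}

// Partition by the two grouping keys; collect_list over this window
// returns, for each row, the list of values from its whole group
val w = Window.partitionBy("categoryAB", "categoryXY")

df.withColumn("group1", collect_list(col("value1")).over(w))
  .withColumn("group2", collect_list(col("value2")).over(w))
  .show()
```

The equivalent in Spark SQL, since you mentioned a generic SQL answer could be guiding, would be `collect_list(value1) OVER (PARTITION BY categoryAB, categoryXY)`; in standard SQL dialects the aggregate is usually called `array_agg` instead.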
Upvotes: 1