Vijay_Shinde

Reputation: 1352

How to aggregate values without explode in Spark using Scala

I am using Spark 2.2 with Scala.

Input data:

{"amount":"2.00","cal_group":[{}],"set_id":7057}
{"amount":"1.00","cal_group":[{}],"set_id":7057}
{"amount":"7.00","cal_group": [{"abc_cd":"abc00160","abc_cnt":6.0,"cde_cnt":7.0},{"abc_cd":"abc00160","abc_cnt":5.0,"cde_cnt":2.0},{"abc_cd":"abc00249","abc_cnt":0.0,"cde_cnt":1.0}],"set_id":7057}

Input dataframe:

[2.00,WrappedArray([null,null,null]),7057]
[1.00,WrappedArray([null,null,null]),7057]
[7.00,WrappedArray([abc00160,6.0,7.0],[abc00160,5.0,2.0],[abc00249,0.0,1.0]),7057]

Input data schema:

|-- amount: string (nullable = true)
|-- cal_group: array (nullable = true)
|    |-- element: struct (containsNull = true)
|    |    |-- abc_cd: string (nullable = true)
|    |    |-- abc_cnt: double (nullable = true)
|    |    |-- cde_cnt: double (nullable = true)
|-- set_id: double

Note: each element of the cal_group array is a struct containing abc_cd and two measure columns (abc_cnt and cde_cnt).

I want to perform two levels of aggregation on the input data, described below as Step 1 and Step 2.

Step 1:

Get the sum of amount for each set_id, and drop the null entries while running collect_list on cal_group.

I have tried the following code:

val res1=res.groupBy($"set_id").agg(sum($"amount").as('amount_total),collect_list(struct($"cal_group")).as('finalgroup))

It gives me the summed amount as expected, but I don't know how to skip the null entries of the cal_group column.

Output of Step 1:

[7057,10.00,WrappedArray([WrappedArray([null,null,null])],[WrappedArray([null,null,null])],[WrappedArray([null,null,null])],[WrappedArray([abc00160,6.0,7.0],[abc00160,5.0,2.0],[abc00249,0.0,1.0])])]
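For reference, the Step 1 goal can be stated as ordinary Scala collection code. This is a minimal sketch without Spark: the case classes `CalGroup` and `InputRow` are hypothetical stand-ins for the schema above (with `Option` modelling the nullable fields), and an entry with no `abcCd` is assumed to be one of the `{}` structs that should be dropped.

```scala
// Hypothetical in-memory model of the input schema (not a Spark API).
final case class CalGroup(abcCd: Option[String], abcCnt: Option[Double], cdeCnt: Option[Double])
final case class InputRow(amount: String, calGroup: Seq[CalGroup], setId: Double)

object Step1Sketch {
  val rows: Seq[InputRow] = Seq(
    InputRow("2.00", Seq(CalGroup(None, None, None)), 7057),
    InputRow("1.00", Seq(CalGroup(None, None, None)), 7057),
    InputRow("7.00", Seq(
      CalGroup(Some("abc00160"), Some(6.0), Some(7.0)),
      CalGroup(Some("abc00160"), Some(5.0), Some(2.0)),
      CalGroup(Some("abc00249"), Some(0.0), Some(1.0))), 7057)
  )

  // Per set_id: sum the string amounts, and keep only the groups whose abc_cd
  // is present, so arrays holding just the all-null {} struct are skipped.
  val result: Map[Double, (Double, Seq[Seq[CalGroup]])] =
    rows.groupBy(_.setId).map { case (setId, rs) =>
      val total = rs.map(_.amount.toDouble).sum
      val groups = rs.map(_.calGroup.filter(_.abcCd.isDefined)).filter(_.nonEmpty)
      (setId, (total, groups))
    }

  def main(args: Array[String]): Unit = result.foreach(println)
}
```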

Step 2:

Then I want to aggregate the two measures (abc_cnt, cde_cnt) at the abc_cd level.

This aggregation can be done by applying explode to the cal_group column, but that turns every array element into its own row, which increases the number of rows and the volume of data.
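The row growth comes from explode emitting one output row per array element, much like `flatMap` on a Scala collection. A minimal plain-Scala sketch (no Spark) over the three input records, assuming each `[{}]` array still holds one all-null struct and therefore still produces one row; the `(set_id, elements)` tuple layout is only an illustration:

```scala
object ExplodeGrowthSketch {
  // (set_id, elements of cal_group) for the three input records;
  // each [{}] array still holds one (all-null) struct, shown here as "{}".
  val records: Seq[(Int, Seq[String])] = Seq(
    (7057, Seq("{}")),
    (7057, Seq("{}")),
    (7057, Seq("abc00160", "abc00160", "abc00249"))
  )

  // explode(cal_group) behaves like flatMap: one output row per element.
  val exploded: Seq[(Int, String)] =
    records.flatMap { case (setId, group) => group.map(elem => (setId, elem)) }

  def main(args: Array[String]): Unit =
    println(s"${records.size} input rows -> ${exploded.size} exploded rows")
}
```

Here the 3 input records become 5 rows; with real data the growth depends on the array lengths.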

So I tried exploding the array and grouping by abc_cd.

Sample code that uses explode to compute the sums:

val res2 = res1.select($"set_id", $"amount_total", explode($"cal_group").as("cal_group"))
val res3 = res2.groupBy($"set_id", $"cal_group.abc_cd")
               .agg(sum($"cal_group.abc_cnt").as('abc_cnt_sum),
                    sum($"cal_group.cde_cnt").as('cde_cnt_sum))

I don't want to explode the cal_group column, because doing so increases the data volume.

Expected output after Step 2:

[7057,10.00,WrappedArray([WrappedArray([null,null,null])],
                         [WrappedArray([null,null,null])],
                         [WrappedArray([null,null,null])],
                         [WrappedArray([abc00160,11.0,9.0],
                                       [abc00249,0.0,1.0])])]

Is there a way to aggregate at the record level and remove the null structs before collecting?

Thanks in advance.
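What is being asked for is essentially a two-phase reduce: first collapse each record's cal_group array into per-abc_cd sums, then merge those small results per set_id. Below is a minimal plain-Scala sketch of that idea (no Spark); `perRecord`, `merge`, and the `(abc_cd, abc_cnt, cde_cnt)` tuple layout are hypothetical names chosen for illustration. In Spark 2.2 this logic would most likely have to run inside a UDF or a typed `Aggregator`.

```scala
object RecordLevelSketch {
  type Sums = Map[String, (Double, Double)]

  // Add one (abc_cnt, cde_cnt) pair to the running sums for abc_cd `cd`.
  private def add(acc: Sums, cd: String, a: Double, c: Double): Sums = {
    val (sa, sc) = acc.getOrElse(cd, (0.0, 0.0))
    acc.updated(cd, (sa + a, sc + c))
  }

  // Phase 1: collapse one record's array into per-abc_cd sums,
  // skipping entries whose abc_cd is null (the {} structs).
  def perRecord(group: Seq[(String, Double, Double)]): Sums =
    group.filter(_._1 != null).foldLeft(Map.empty[String, (Double, Double)]) {
      case (acc, (cd, a, c)) => add(acc, cd, a, c)
    }

  // Phase 2: merge the small per-record results for the same set_id.
  def merge(x: Sums, y: Sums): Sums =
    y.foldLeft(x) { case (acc, (cd, (a, c))) => add(acc, cd, a, c) }

  // The three records of set_id 7057; nulls stand for the empty {} structs.
  val records: Seq[Seq[(String, Double, Double)]] = Seq(
    Seq((null, 0.0, 0.0)),
    Seq((null, 0.0, 0.0)),
    Seq(("abc00160", 6.0, 7.0), ("abc00160", 5.0, 2.0), ("abc00249", 0.0, 1.0))
  )

  val result: Sums =
    records.map(perRecord).foldLeft(Map.empty[String, (Double, Double)])(merge)

  def main(args: Array[String]): Unit = println(result)
}
```

Because each record is reduced before merging, the intermediate data per record is at most one entry per distinct abc_cd rather than one row per array element.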

Upvotes: 2

Views: 1632

Answers (2)

Ramesh Maharjan

Reputation: 41957

You can define a UDF for the second-level aggregation:

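import org.apache.spark.sql.Row
import org.apache.spark.sql.functions._

// Flattens the collected array of cal_group arrays, drops structs whose abc_cd
// is null (the empty {} entries), and sums abc_cnt and cde_cnt per abc_cd.
// Note: asInstanceOf[Double] turns a null measure into 0.0.
val aggregateUdf = udf((nestedArray: Seq[Seq[Row]]) =>
  nestedArray
    .flatMap(x => x.map(y => (y(0).asInstanceOf[String], (y(1).asInstanceOf[Double], y(2).asInstanceOf[Double]))))
    .filterNot(_._1 == null)
    .groupBy(_._1)
    .map(x => (x._1, x._2.map(_._2._1).sum, x._2.map(_._2._2).sum))
    .toArray
)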

Then call the UDF after the first aggregation. The first aggregation also needs a change: pass $"cal_group" to collect_list directly instead of wrapping it in struct.

val finalRes=res
  .groupBy($"set_id")
  .agg(sum($"amount").as('amount_total),collect_list($"cal_group").as('finalgroup))
  .withColumn("finalgroup", aggregateUdf('finalgroup))

The resulting finalRes is:

+------+------------+-----------------------------------------+
|set_id|amount_total|finalgroup                               |
+------+------------+-----------------------------------------+
|7057  |10.0        |[[abc00249,0.0,1.0], [abc00160,11.0,9.0]]|
+------+------------+-----------------------------------------+
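Because the UDF body is plain Scala collection code, its logic can be checked without a SparkSession. The sketch below repeats the same flatMap / filterNot / groupBy pipeline but replaces `Row` with an `(abc_cd, abc_cnt, cde_cnt)` tuple, an assumption made only so it runs without Spark on the classpath; the empty entries are modelled as `(null, 0.0, 0.0)`, matching what `asInstanceOf[Double]` yields for null fields.

```scala
object AggregateUdfLogicSketch {
  // Same pipeline as the UDF body, with tuples standing in for Spark Rows.
  def aggregate(nested: Seq[Seq[(String, Double, Double)]]): Array[(String, Double, Double)] =
    nested
      .flatMap(x => x.map(y => (y._1, (y._2, y._3))))
      .filterNot(_._1 == null)                 // drop the {} entries
      .groupBy(_._1)                           // group by abc_cd
      .map(x => (x._1, x._2.map(_._2._1).sum, x._2.map(_._2._2).sum))
      .toArray

  // Hypothetical finalgroup value for set_id 7057.
  val finalgroup: Seq[Seq[(String, Double, Double)]] = Seq(
    Seq((null, 0.0, 0.0)),
    Seq((null, 0.0, 0.0)),
    Seq(("abc00160", 6.0, 7.0), ("abc00160", 5.0, 2.0), ("abc00249", 0.0, 1.0))
  )

  def main(args: Array[String]): Unit =
    aggregate(finalgroup).sortBy(_._1).foreach(println)
}
```

The order of the resulting array is not guaranteed (it comes from a `Map`), which is why the table above lists abc00249 first; sort it if a stable order matters.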

Upvotes: 1

sujit

Reputation: 2328

I loaded the following JSON data to get a schema identical to yours:

{"amount":"2.00","cal_group":[{}],"set_id":7057.0}
{"amount":"1.00","cal_group":[{}],"set_id":7057}
{"amount":"7.00","cal_group": [{"abc_cd":"abc00160","abc_cnt":6.0,"cde_cnt":7.0},{"abc_cd":"abc00160","abc_cnt":5.0,"cde_cnt":2.0},{"abc_cd":"abc00249","abc_cnt":0.0,"cde_cnt":1.0}],"set_id":7057}

Regarding "But here I don't know how to skip null WrappedArray cal_group column":

I think collect_list already skips null values, but in your case it can't, because wrapping cal_group in struct(...) produces a non-null struct even when cal_group itself is null. The struct isn't needed, so the correct transformation for Step 1 is:

val res1 = res.groupBy($"set_id").agg(sum($"amount").as('amount_total), collect_list($"cal_group").as('finalgroup))

which gives the following output (from show(false) and printSchema):

+------+------------+--------------------------------------------------------------------------+
|set_id|amount_total|finalgroup                                                                |
+------+------------+--------------------------------------------------------------------------+
|7057.0|10.0        |[WrappedArray([abc00160,6.0,7.0], [abc00160,5.0,2.0], [abc00249,0.0,1.0])]|
+------+------------+--------------------------------------------------------------------------+
root
 |-- set_id: double (nullable = true)
 |-- amount_total: double (nullable = true)
 |-- finalgroup: array (nullable = true)
 |    |-- element: array (containsNull = true)
 |    |    |-- element: struct (containsNull = true)
 |    |    |    |-- abc_cd: string (nullable = true)
 |    |    |    |-- abc_cnt: double (nullable = true)
 |    |    |    |-- cde_cnt: double (nullable = true)
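The difference between the two collect_list calls can be mimicked with Scala collections, where `None` plays the role of a null cal_group and `Tuple1` plays the role of the wrapping struct. This is a plain-Scala analogy rather than Spark code, and it assumes (as the output above suggests) that the two `[{}]` rows end up with a null cal_group after loading.

```scala
object CollectListNullsSketch {
  // The cal_group column across three rows; None stands for a null array.
  val calGroup: Seq[Option[Seq[String]]] =
    Seq(None, None, Some(Seq("abc00160", "abc00160", "abc00249")))

  // Like collect_list($"cal_group"): null inputs are skipped.
  val collected: Seq[Seq[String]] = calGroup.flatten

  // Like collect_list(struct($"cal_group")): the wrapper itself is never null,
  // so every row contributes an element even when the field inside is null.
  val collectedStructs: Seq[Tuple1[Option[Seq[String]]]] = calGroup.map(Tuple1(_))

  def main(args: Array[String]): Unit = {
    println(collected.size)        // 1
    println(collectedStructs.size) // 3
  }
}
```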

Step 2

The code below assumes the Step 1 code above has already been run. It still uses explode.

With this data structure you have to explode twice, because after grouping on set_id the collected cal_group values form a doubly nested array. The following code gives the desired output:

val res2 = res1.select($"set_id",$"amount_total",explode($"finalgroup").as("cal_group"))
val res3 = res2.select($"set_id",$"amount_total",explode($"cal_group").as("cal_group_exp"))
val res4 = res3.groupBy($"set_id",$"cal_group_exp.abc_cd")
                          .agg(sum($"cal_group_exp.abc_cnt").as('abc_cnt_sum),
                              sum($"cal_group_exp.cde_cnt").as('cde_cnt_sum))
res4.show(false)

with output:

+------+--------+-----------+-----------+
|set_id|  abc_cd|abc_cnt_sum|cde_cnt_sum|
+------+--------+-----------+-----------+
|7057.0|abc00160|       11.0|        9.0|
|7057.0|abc00249|        0.0|        1.0|
+------+--------+-----------+-----------+

Upvotes: 0
