tusworten

Reputation: 21

Group by and aggregate tuples in Spark SQL

I'm working with Spark SQL and Java. I have a dataset with duplicate clients, grouped by ENTITY and DOCUMENT_ID. I added a ROWNUMBER column so I know how many clients there are in each group (and which rows I have to compare):

.withColumn("ROWNUMBER", row_number().over(Window.partitionBy("ENTITY", "ENTITY_DOC").orderBy("ID")))
+---------+----------+-----------------+----------+-----------+-----------+--------+------------+
|ROWNUMBER|  ENTITY  |      ENTITY_DOC |    ID    |  BLOCK    |  TYPE_DOC |COD_BEST|COD_CRITERIO|
+---------+----------+-----------------+----------+-----------+-----------+--------+------------+
|        1|       182|000004693R       |   5254578|          3|         01|       0|           0|
|        2|       182|000004693R       |  99841470|          0|         01|       0|           0|
|        3|       182|000004693R       |  45866239|          3|         01|       0|           0|
|        1|       182|000081638B       |  99804050|          0|         01|       0|           0|
|        2|       182|000081638B       |  99803968|          0|         01|       0|           0|
|        3|       182|000081638B       |  99803958|          0|         01|       0|           0|
|        4|       182|000081638B       |  99804054|          0|         01|       0|           0|
|        5|       182|000081638B       |  99787706|          1|         01|       0|           0|
|        6|       182|000081638B       |  99803930|          0|         01|       0|           0|
|        1|       182|000107084L       |  99819126|          0|         01|       0|           0|
|        2|       182|000107084L       |  99818446|          0|         01|       0|           0|
+---------+----------+-----------------+----------+-----------+-----------+--------+------------+

Now I have to compare pairs of rows in order to decide which is the best.

First compare row number 1 vs row number 2; whichever wins (say row number 2) is then compared against row number 3; if row number 3 wins, it is compared against row number 4, and so on.

The best row is decided by a set of business criteria, for example:

//criteria 1
BLOCK = 1 VS BLOCK = 1
   //tie: go to the next criteria

//criteria 2
BLOCK = 2 VS BLOCK = 1
   //the best is the row with BLOCK = 2

//criteria 3
TYPE_DOC = 1 VS TYPE_DOC = 1
   //tie: go to the next criteria

//criteria 4
TYPE_DOC = 1 VS TYPE_DOC = 2
   //the best is the row with TYPE_DOC = 1

(not a rigorous example, just to give an idea)
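
To make it concrete, the kind of pairwise check I mean would look roughly like this in plain Java. BLOCK and TYPE_DOC are real columns from my table, but the two rules shown are simplified stand-ins for the actual business criteria:

// Simplified stand-in for one row of a group.
class Candidate {
    int block;
    String typeDoc;
    Candidate(int block, String typeDoc) {
        this.block = block;
        this.typeDoc = typeDoc;
    }
}

// Walk the criteria in order; the first criterion where the rows differ decides.
static Candidate best(Candidate a, Candidate b) {
    // criteria 1/2: if BLOCK differs, the higher BLOCK wins (illustrative rule)
    if (a.block != b.block) {
        return a.block > b.block ? a : b;
    }
    // criteria 3/4: if TYPE_DOC differs, the lower TYPE_DOC wins (illustrative rule)
    if (!a.typeDoc.equals(b.typeDoc)) {
        return a.typeDoc.compareTo(b.typeDoc) < 0 ? a : b;
    }
    return a; // complete tie: keep the first row
}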


In the end I need to know which row is the best in each group and by which criterion it was selected, but I don't know how to iterate over each group to compare the fields of its rows.

Would it be very difficult to do?

Upvotes: 1

Views: 456

Answers (1)

blackbishop

Reputation: 32690

You can first assign a row_number to each duplicate, then build a map from the columns, adding the row number as a suffix to each key. Finally, group by ENTITY and DOCUMENT_ID, collect the list of maps, concatenate them, and pivot after exploding the map.

Note that I'm mainly using SQL expressions here, as I'm not very familiar with the Spark Java API, but the logic is the same if you want to convert them to API functions.

import org.apache.spark.sql.expressions.Window;
import static org.apache.spark.sql.functions.*;

Dataset<Row> tuples = duplicates.withColumn(
    "rn",
    row_number().over(Window.partitionBy("ENTITY", "DOCUMENT_ID").orderBy("ID"))
).withColumn(
    "dupes",
    expr("map(concat('COUNTRY_', rn), COUNTRY, concat('ID_', rn), ID, concat('CUSTOMER_NAME_', rn), CUSTOMER_NAME)")
).groupBy("ENTITY", "DOCUMENT_ID").agg(
    collect_list("dupes").alias("dupes")
).selectExpr(
    "ENTITY",
    "DOCUMENT_ID",
    "explode(aggregate(dupes, cast(map() as map<string,string>), (acc, x) -> map_concat(acc, x)))"
).groupBy(
    "ENTITY", "DOCUMENT_ID"
).pivot("key").agg(first("value"));


tuples.show();

//+------+-----------+---------+---------+---------------+---------------+--------+--------+
//|ENTITY|DOCUMENT_ID|COUNTRY_1|COUNTRY_2|CUSTOMER_NAME_1|CUSTOMER_NAME_2|    ID_1|    ID_2|
//+------+-----------+---------+---------+---------------+---------------+--------+--------+
//|    11|  A06804173|        9|        9|     Elton John|     Elton John|12341000|13701921|
//+------+-----------+---------+---------+---------------+---------------+--------+--------+
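
With the duplicates side by side like this, the pairwise criteria from the question become plain column expressions. A minimal sketch for the two-duplicate case, assuming BLOCK and TYPE_DOC are also included in the map above (so that BLOCK_1, BLOCK_2, TYPE_DOC_1, and TYPE_DOC_2 exist after the pivot) and using placeholder rules (higher BLOCK wins, then lower TYPE_DOC):

import static org.apache.spark.sql.functions.*;

// Pick the winner between the two pivoted candidates and record
// which criterion decided it. The rules are placeholders.
Dataset<Row> decided = tuples.withColumn(
    "BEST_ROWNUMBER",
    when(col("BLOCK_1").gt(col("BLOCK_2")), lit(1))
        .when(col("BLOCK_2").gt(col("BLOCK_1")), lit(2))
        .when(col("TYPE_DOC_1").leq(col("TYPE_DOC_2")), lit(1))
        .otherwise(lit(2))
).withColumn(
    "CRITERIA",
    when(col("BLOCK_1").notEqual(col("BLOCK_2")), lit("BLOCK"))
        .otherwise(lit("TYPE_DOC"))
);

With more than two duplicates you would chain further comparisons or fold over the suffixed columns.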

Another way would be to group by your key columns and collect a list of structs, then, using the max size of the resulting arrays, access the elements and create multiple columns. Something like this:

import java.util.stream.IntStream;
import java.util.stream.Stream;
import org.apache.spark.sql.Column;

Dataset<Row> tuples = duplicates.groupBy("ENTITY", "DOCUMENT_ID").agg(
    collect_list(expr("struct(COUNTRY, ID, CUSTOMER_NAME)")).alias("dupes")
);

int maxSize = tuples.select(max(size(col("dupes")))).first().getInt(0);

Column[] dupes = IntStream.rangeClosed(0, maxSize - 1)
        .mapToObj(i -> new Column[]{
                col("dupes").getItem(i).getField("COUNTRY").alias("COUNTRY_" + i),
                col("dupes").getItem(i).getField("ID").alias("ID_" + i),
                col("dupes").getItem(i).getField("CUSTOMER_NAME").alias("CUSTOMER_NAME_" + i),
        }).flatMap(Stream::of).toArray(Column[]::new);

tuples.select(
    Stream.of(new Column[]{col("ENTITY"), col("DOCUMENT_ID")}, dupes)
            .flatMap(Stream::of).toArray(Column[]::new)
).show();
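
If you only need the winning row per group rather than all duplicates side by side, a third option is to reduce the collected array directly with the aggregate higher-order function (Spark 2.4+). A sketch, again with a placeholder rule (higher BLOCK wins) standing in for the real criteria:

// Collect the deciding columns into structs, then fold the array pairwise:
// the accumulator always holds the best row seen so far.
Dataset<Row> best = duplicates.groupBy("ENTITY", "DOCUMENT_ID").agg(
    collect_list(expr("struct(BLOCK, TYPE_DOC, ID)")).alias("dupes")
).withColumn(
    "best",
    expr("aggregate(slice(dupes, 2, size(dupes)), dupes[0], " +
         "(acc, x) -> IF(x.BLOCK > acc.BLOCK, x, acc))")
);

To also report which criterion decided, you could carry an extra field in the accumulator struct.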

Upvotes: 1
