amateur-coder

Reputation: 123

Spark: How to preserve order with collect_set over a partition of a DataFrame?

I have partitioned and sorted the data by one column (rank) as below:

+-------+---------+----+
|classId|studentId|rank|
+-------+---------+----+
|1      |123      |1   |
|1      |5000     |2   |
|1      |5000     |3   |
|1      |5000     |4   |
|1      |908      |5   |
|1      |908      |6   |
|2      |123      |1   |
|2      |123      |2   |
|2      |123      |3   |
|2      |908      |4   |
+-------+---------+----+

Now I want the following output: for each classId, an array of distinct studentIds in the order of the rank column.

+-------+----------------------------------+
|classId|studentIds                        |
+-------+----------------------------------+
|1      |[123, 5000, 908]                  |
|2      |[123, 908]                        |
+-------+----------------------------------+

I tried collect_list over the partition, but that gives me duplicates, albeit in the correct order:

+-------+---------------------------------+
|classId|studentIds                       |
+-------+---------------------------------+
|1      |[123, 5000, 5000, 5000, 908, 908]|
|2      |[123, 123, 123, 908]             |
+-------+---------------------------------+

I tried collect_set over the partition, which gives me distinct values but in the wrong order of studentIds:

+-------+----------------+
|classId|studentIds      |
+-------+----------------+
|1      |[5000, 123, 908]|
|2      |[123, 908]      |
+-------+----------------+

Code:

//Imports
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

//Sample data
val simpleData = Seq(
  ("2", "123", 1), ("2", "908", 4),
  ("1", "123", 1), ("1", "5000", 3), ("1", "908", 5), ("1", "5000", 2),
  ("1", "5000", 4), ("1", "908", 6), ("2", "123", 2), ("2", "123", 3)
)
val df = simpleData.toDF("classId", "studentId", "rank")

//Processing: build a running collect_list over a rank-ordered window,
//then keep the last (i.e. complete) list per classId
df.sort(asc("classId"), asc("rank"))
  .withColumn("studentIds", collect_list("studentId")
    .over(Window.partitionBy("classId").orderBy("rank")))
  .groupBy("classId")
  .agg(last("studentIds") as "studentIds")

Upvotes: 0

Views: 1116

Answers (1)

koiralo

Reputation: 23109

You can use the array_distinct function to remove duplicates after collect_list. It keeps the first occurrence of each element, so the rank order is preserved:

df.sort(asc("classId"), asc("rank"))
  .withColumn("studentIds", array_distinct(collect_list("studentId")
    .over(Window.partitionBy("classId").orderBy("rank"))))
  .groupBy("classId")
  .agg(last("studentIds") as "studentIds")

Output:

+-------+----------------+
|classId|studentIds      |
+-------+----------------+
|1      |[123, 5000, 908]|
|2      |[123, 908]      |
+-------+----------------+
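The order survives deduplication because array_distinct keeps the first occurrence of each element in encounter order, which matches the semantics of Scala's own Seq.distinct. A minimal plain-Scala sketch of that behaviour, using the collected list for classId 1:

```scala
object DistinctSketch extends App {
  // Seq.distinct mirrors array_distinct's dedup semantics: it keeps
  // the first occurrence of each element, in encounter order.
  val collected = Seq("123", "5000", "5000", "5000", "908", "908")
  val deduped = collected.distinct
  println(deduped) // List(123, 5000, 908)
}
```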

Upvotes: 3
