salamanka44

Reputation: 944

collect_list keeping order (Spark SQL / Scala)

I have a table like this :

Clients   City   Timestamp
1         NY        0
1         WDC       10
1         NY        11    
2         NY        20
2         WDC       15

What I want as output is to collect all the cities ordered by timestamp (each timestamp has a unique city per client), but without displaying the timestamp. The final list must only contain the cities, in order. So for that example, it gives something like this:

Clients   my_list
1         NY - WDC - NY
2         WDC - NY

Maybe I should generate a list using the timestamp, then remove the timestamp from this list. I don't know...

I am using Spark SQL with Scala. I tried to use collect_list in both SQL and Scala, but it seems that we lose the ordering after using it.
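My naive attempt looked roughly like this (df being my input table); the order of the collected cities is not guaranteed, because collect_list is a non-deterministic aggregate and the shuffle introduced by groupBy can reorder rows:

import org.apache.spark.sql.functions.collect_list

// Naive aggregation: no ordering guarantee on the collected cities
val naive = df.groupBy("Clients").agg(collect_list("City").as("my_list"))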

Can you help me fix this issue?

Upvotes: 4

Views: 1541

Answers (3)

Vincent Doba

Reputation: 5078

Since Spark 2.4, you can apply your first idea: create a struct of timestamp and city, collect those structs into a list, sort the list, and then keep only the city of each struct in the list:

import org.apache.spark.sql.functions.{array_sort, col, collect_list, struct}

val result = inputDf.groupBy("Clients")
  .agg(
    array_sort(                                // sorts the structs by their first field, Timestamp
      collect_list(
        struct(col("Timestamp"), col("City"))  // pair each city with its timestamp
      )
    ).getField("City").as("Cities")            // keep only the City field of each struct
  )

With the following inputDf dataframe:

+-------+----+---------+
|Clients|City|Timestamp|
+-------+----+---------+
|1      |NY  |0        |
|1      |WDC |10       |
|1      |NY  |11       |
|2      |NY  |20       |
|2      |WDC |15       |
+-------+----+---------+

You will get the following result dataframe:

+-------+-------------+
|Clients|Cities       |
+-------+-------------+
|1      |[NY, WDC, NY]|
|2      |[WDC, NY]    |
+-------+-------------+

Using this method, you will only shuffle your input dataframe once.
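Since the question also mentions Spark SQL, the same approach can be sketched in SQL (assuming the input is registered as a temporary view named visits; array_sort and the transform higher-order function are both available since Spark 2.4):

inputDf.createOrReplaceTempView("visits")

val sqlResult = spark.sql("""
  SELECT
    Clients,
    transform(
      array_sort(collect_list(struct(Timestamp, City))),  -- sort by Timestamp, the first struct field
      x -> x.City                                         -- then keep only the city
    ) AS Cities
  FROM visits
  GROUP BY Clients
""")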

Upvotes: 0

apnith

Reputation: 325

I would simply do the below:

val a = Seq((1,"NY",0),(1,"WDC",10),(1,"NY",11),(2,"NY",20),(2,"WDC",15))
    .toDF("client", "city", "timestamp")

val w = Window.partitionBy($"client").orderBy($"timestamp")
val b = a.withColumn("sorted_list", collect_list($"city").over(w))

Here I used a Window to partition over client, ordered by timestamp. At this point you have a dataframe like this:

+------+----+---------+-------------+
|client|city|timestamp|sorted_list  |
+------+----+---------+-------------+
|1     |NY  |0        |[NY]         |
|1     |WDC |10       |[NY, WDC]    |
|1     |NY  |11       |[NY, WDC, NY]|
|2     |WDC |15       |[WDC]        |
|2     |NY  |20       |[WDC, NY]    |
+------+----+---------+-------------+

The new column sorted_list now holds the running list of cities, ordered by timestamp, but there are duplicated rows per client. To remove the duplicates, group by client and keep the max value of the list for each group (see the note after the output below):

val c = b
  .groupBy($"client")
  .agg(max($"sorted_list").alias("sorted_timestamp"))

c.show(false)

+------+----------------+
|client|sorted_timestamp|
+------+----------------+
|1     |[NY, WDC, NY]   |
|2     |[WDC, NY]       |
+------+----------------+
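A note on why max picks the full list: Spark compares arrays lexicographically, and within a client each partial list is a prefix of the complete one, so the complete list is the maximum. A minimal sketch to check this (assuming the same spark session and imports as above):

// Each partial list is a prefix of the complete list, so max returns the complete one
Seq((1, Seq("NY")), (1, Seq("NY", "WDC")), (1, Seq("NY", "WDC", "NY")))
  .toDF("client", "list")
  .groupBy($"client")
  .agg(max($"list").alias("max_list"))
  .show(false)   // max_list is [NY, WDC, NY], the longest (complete) list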

Upvotes: 4

sangam.gavini

Reputation: 196

// The below can be helpful for you to achieve your target
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{collect_list, row_number}
import spark.implicits._

val input_rdd = spark.sparkContext.parallelize(List(("1","NY","0"),("1","WDC","10"),("1","NY","11"),("2","NY","20"),("2","WDC","15")))
val input_df = input_rdd.toDF("clients","city","Timestamp")

// running list per client; note Timestamp is a String here, so the ordering is
// lexicographic (it happens to match numeric order for these values)
val winspec1 = Window.partitionBy($"clients").orderBy($"Timestamp")
val input_df1 = input_df.withColumn("collect", collect_list($"city").over(winspec1))
input_df1.show
Output:
+-------+----+---------+-------------+
|clients|city|Timestamp|      collect|
+-------+----+---------+-------------+
|      1|  NY|        0|         [NY]|
|      1| WDC|       10|    [NY, WDC]|
|      1|  NY|       11|[NY, WDC, NY]|
|      2| WDC|       15|        [WDC]|
|      2|  NY|       20|    [WDC, NY]|
+-------+----+---------+-------------+

val winspec2 = Window.partitionBy($"clients").orderBy($"Timestamp".desc)

// keep only the last (complete) list per client
input_df1
  .withColumn("number", row_number().over(winspec2))
  .filter($"number" === 1)
  .drop("number", "Timestamp", "city")
  .show
Output:
+-------+-------------+
|clients|      collect|
+-------+-------------+
|      1|[NY, WDC, NY]|
|      2|    [WDC, NY]|
+-------+-------------+

Upvotes: 1
