Lag with count in Spark Scala

Let's say I have a dataframe like this one:

import spark.implicits._  // for .toDF and the $"..." column syntax

val df = Seq(
  "2020-01-01",
  "2020-02-01",
  "2020-02-01",
  "2020-03-01",
  "2020-03-01",
  "2020-03-01",
  "2020-04-01",
  "2020-05-01"
).toDF("date")

+----------+
|      date|
+----------+
|2020-01-01|
|2020-02-01|
|2020-02-01|
|2020-03-01|
|2020-03-01|
|2020-03-01|
|2020-04-01|
|2020-05-01|
+----------+

Is there a way to calculate the count of records for the previous month, like so:

+----------+----------------+
|      date|prev_month_count|
+----------+----------------+
|2020-01-01|            null|
|2020-02-01|               1|
|2020-02-01|               1|
|2020-03-01|               2|
|2020-03-01|               2|
|2020-03-01|               2|
|2020-04-01|               3|
|2020-05-01|               1|
+----------+----------------+

I can achieve this by deriving a second dataframe from the original one (shifting dates, then aggregating) and joining it back to the original dataframe. But I was wondering whether a solution combining lag and count over windows exists.
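For illustration, here is a minimal sketch of that join-based approach, assuming date is always the first day of its month (prevCounts is just an illustrative name):

import org.apache.spark.sql.functions._

// Count rows per month, then shift each month forward by one so the count
// lines up with the rows of the following month.
val prevCounts = df
  .groupBy("date")
  .agg(count("date").as("prev_month_count"))
  .select(add_months(col("date"), 1).cast("string").as("date"), col("prev_month_count"))

df.join(prevCounts, Seq("date"), "left").orderBy("date").show(false)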

Update

Combining lag with count over windows, first with Spark SQL:

df.createOrReplaceTempView("test")
sql("""select date, lag(count(date) over(partition by date order by date)) over(order by date) as prev_month_count from test""").show(false)

and then with the DataFrame API:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

df
  .withColumn("prev_month_count", lag(count($"date").over(Window.partitionBy($"date").orderBy($"date")), 1).over(Window.orderBy($"date")))
  .show(false)

I am getting unexpected results:

+----------+----------------+
|date      |prev_month_count|
+----------+----------------+
|2020-01-01|null            |
|2020-02-01|1               |
|2020-02-01|2               |
|2020-03-01|2               |
|2020-03-01|3               |
|2020-03-01|3               |
|2020-04-01|3               |
|2020-05-01|1               |
+----------+----------------+

What am I missing here?

Upvotes: 0

Views: 975

Answers (2)

linusRian

Reputation: 340

Definitely! This can be accomplished, as you rightly said.

The Spark SQL way of doing this:

EDIT 1: Updated the query to fix the "multiple count values for each date" issue: wrapping the expression in first(...) over (partition by date) makes every row of a given month carry the same value, namely the one computed for the month's first row.

df.createOrReplaceTempView("test")

spark.sql("""
  select date,
         first(lag(count(date) over (partition by date order by date))
               over (order by date))
           over (partition by date) as prev_month_count
  from test
""").show(false)

+----------+----------------+
|date      |prev_month_count|
+----------+----------------+
|2020-01-01|null            |
|2020-02-01|1               |
|2020-02-01|1               |
|2020-03-01|2               |
|2020-03-01|2               |
|2020-03-01|2               |
|2020-04-01|3               |
|2020-05-01|1               |
+----------+----------------+


If you prefer not to use Spark SQL, this can be done with the DataFrame API as well, as sketched below.
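A minimal sketch of a direct DataFrame translation of the query above (the intermediate column name raw is made up here; it relies on the same within-partition ordering behavior as the SQL version):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val perDate = Window.partitionBy("date").orderBy("date")
val ordered = Window.orderBy("date")

df
  // lag of the running per-date count, exactly as in the SQL query
  .withColumn("raw", lag(count(col("date")).over(perDate), 1).over(ordered))
  // take the first value within each date partition so all rows of a month agree
  .withColumn("prev_month_count", first(col("raw")).over(Window.partitionBy("date")))
  .drop("raw")
  .show(false)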

Upvotes: 1

Leo C

Reputation: 22449

Here's one approach that:

  1. computes the count per partition (i.e. per month)
  2. generates an intermediate column temp that is null for all rows except the first row of each partition, which is assigned the count from its previous row
  3. fills the remaining nulls with the last non-null value seen so far in the partition, via the Window function last

Sample code:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w = Window.partitionBy("date").orderBy("date")
val w0 = Window.orderBy("date")

df.
  withColumn("count", count("date").over(w)).                                    // step 1
  withColumn("temp", when(row_number.over(w) === 1, lag("count", 1).over(w0))).  // step 2
  withColumn(
    "prev_mo_count",
    last("temp", ignoreNulls=true).over(w.rowsBetween(Window.unboundedPreceding, 0))  // step 3
  ).
  show
/*
+----------+-----+----+-------------+
|      date|count|temp|prev_mo_count|
+----------+-----+----+-------------+
|2020-01-01|    1|null|         null|
|2020-02-01|    2|   1|            1|
|2020-02-01|    2|null|            1|
|2020-03-01|    3|   2|            2|
|2020-03-01|    3|null|            2|
|2020-03-01|    3|null|            2|
|2020-04-01|    1|   3|            3|
|2020-05-01|    1|   1|            1|
+----------+-----+----+-------------+
*/

Note that column date is assumed to always be the 1st day of the month (as implied by the provided sample data). If the dates can fall on arbitrary days, the month partitioning should be defined differently, e.g.:

Window.partitionBy(date_format($"date", "yyyyMM")).orderBy("date")
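Under that assumption, a minimal sketch of the same pipeline with month-based partitioning (wMo is just an illustrative name; w0 is unchanged since dates still sort chronologically):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

// Partition by calendar month ("yyyyMM") instead of by exact date.
val wMo = Window.partitionBy(date_format(col("date"), "yyyyMM")).orderBy("date")
val w0 = Window.orderBy("date")

df.
  withColumn("count", count("date").over(wMo)).
  withColumn("temp", when(row_number.over(wMo) === 1, lag("count", 1).over(w0))).
  withColumn(
    "prev_mo_count",
    last("temp", ignoreNulls=true).over(wMo.rowsBetween(Window.unboundedPreceding, 0))
  ).
  show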

Upvotes: 1
