Reputation: 99
Let's say I have a dataframe like this one:
val df = Seq(
("2020-01-01"),
("2020-02-01"),
("2020-02-01"),
("2020-03-01"),
("2020-03-01"),
("2020-03-01"),
("2020-04-01"),
("2020-05-01")
).toDF("date")
+----------+
| date|
+----------+
|2020-01-01|
|2020-02-01|
|2020-02-01|
|2020-03-01|
|2020-03-01|
|2020-03-01|
|2020-04-01|
|2020-05-01|
+----------+
Is there a way to calculate the count of records for the previous month, like so:
+----------+----------------+
| date|prev_month_count|
+----------+----------------+
|2020-01-01| null|
|2020-02-01| 1|
|2020-02-01| 1|
|2020-03-01| 2|
|2020-03-01| 2|
|2020-03-01| 2|
|2020-04-01| 3|
|2020-05-01| 1|
+----------+----------------+
I can achieve this by deriving another dataframe from the original one, shifting the dates, aggregating, and joining back to the original dataframe (a rough sketch of that workaround is shown below). But I was wondering whether a solution combining lag and count over windows exists.
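For reference, a minimal sketch of that join-based workaround, assuming date is always the first day of the month as in the sample data (add_months shifts each month's count forward by one so it lines up with the following month):
import org.apache.spark.sql.functions._

// aggregate counts per month, then shift each month forward by one
val prevCounts = df
  .groupBy($"date")
  .agg(count($"date").as("prev_month_count"))
  .withColumn("date", add_months($"date", 1).cast("string"))

// left join keeps 2020-01-01 with a null prev_month_count
df.join(prevCounts, Seq("date"), "left").show(false)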
Using lag with count and windows, both with Spark SQL:
df.createOrReplaceTempView("test")
sql("""select date, lag(count(date) over(partition by date order by date)) over(order by date) as prev_month_count from test""").show(false)
and the DataFrame API:
df
.withColumn("prev_month_count", lag(count($"date").over(Window.partitionBy($"date").orderBy($"date")), 1).over(Window.orderBy($"date")))
.show(false)
I am getting unexpected results:
+----------+----------------+
|date |prev_month_count|
+----------+----------------+
|2020-01-01|null |
|2020-02-01|1 |
|2020-02-01|2 |
|2020-03-01|2 |
|2020-03-01|3 |
|2020-03-01|3 |
|2020-04-01|3 |
|2020-05-01|1 |
+----------+----------------+
What am I missing here?
Upvotes: 0
Views: 975
Reputation: 340
Definitely! This can be accomplished, as you rightly said.
The Spark SQL way of doing this:
EDIT 1: Updated the query to fix the "multiple count values for each date" issue.
df.createOrReplaceTempView("test")
spark.sql("""select date,first(lag(count(date) over(partition by date order by date)) over(order by date)) over(partition by date) as prev_month_count from test""").show(false)
+----------+----------------+
|date |prev_month_count|
+----------+----------------+
|2020-01-01|null |
|2020-02-01|1 |
|2020-02-01|1 |
|2020-03-01|2 |
|2020-03-01|2 |
|2020-03-01|2 |
|2020-04-01|3 |
|2020-05-01|1 |
+----------+----------------+
If you prefer not to use Spark SQL, this can be done with the DataFrame API as well; a sketch follows below.
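For example, a minimal sketch of the same query expressed with the DataFrame API (the nested window expressions mirror the SQL above; the window value names are my own):
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val perDateOrdered = Window.partitionBy($"date").orderBy($"date")
val globalOrder    = Window.orderBy($"date")
val perDate        = Window.partitionBy($"date")

df
  .withColumn(
    "prev_month_count",
    // count per date, lag it by one row globally, then collapse to one value per date
    first(lag(count($"date").over(perDateOrdered), 1).over(globalOrder)).over(perDate))
  .show(false)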
Upvotes: 1
Reputation: 22449
Here's one approach that:

1. computes count per partition (by month)
2. creates a column temp populated with null for all rows except the first row in each partition, which is assigned the count from its previous row
3. backfills the null values from the last non-null value in each partition via the Window function last
Sample code:
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val w  = Window.partitionBy("date").orderBy("date")  // one partition per distinct date (i.e. per month here)
val w0 = Window.orderBy("date")                      // global ordering across all rows

df.
  withColumn("count", count("date").over(w)).
  withColumn("temp", when(row_number.over(w) === 1, lag("count", 1).over(w0))).
  withColumn(
    "prev_mo_count",
    last("temp", ignoreNulls=true).over(w.rowsBetween(Window.unboundedPreceding, 0))
  ).
  show
/*
+----------+-----+----+-------------+
| date|count|temp|prev_mo_count|
+----------+-----+----+-------------+
|2020-01-01| 1|null| null|
|2020-02-01| 2| 1| 1|
|2020-02-01| 2|null| 1|
|2020-03-01| 3| 2| 2|
|2020-03-01| 3|null| 2|
|2020-03-01| 3|null| 2|
|2020-04-01| 1| 3| 3|
|2020-05-01| 1| 1| 1|
+----------+-----+----+-------------+
*/
Note that column date is assumed to always be the 1st day of the month (as implied in the provided sample data). In case the dates carry arbitrary day values, the partitioning should be defined differently, like:
Window.partitionBy(date_format($"date", "yyyyMM")).orderBy("date")
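For instance, a minimal sketch under that assumption (only the per-month window w changes; the global ordering window w0 used by lag should be able to stay as-is):
val w  = Window.partitionBy(date_format($"date", "yyyyMM")).orderBy("date")  // group rows by month rather than by exact date
val w0 = Window.orderBy("date")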
Upvotes: 1