tessa

Reputation: 828

PySpark: get the first date a count increases to the highest count in the list

Using PySpark, I need to get the first date on which a count increases to the highest count in the list. So in this sample, for AZ I need the 2021-03-06 row, because that's the first date the count of 820198 appears and it's higher than all the other counts. If the counts don't change at all across dates, I just need the earliest date in the log. For AK, that would be 2021-01-23.

+-----+------+----------+
|state| count|  log_date|
+-----+------+----------+
|   AZ|820198|2021-03-07|
|   AZ|820198|2021-03-06|
|   AZ|818784|2021-03-05|
|   AZ|801115|2021-03-03|
|   AK| 46819|2021-03-07|
|   AK| 46819|2021-03-06|
|   AK| 46819|2021-03-05|
|   AK| 46819|2021-01-23|
+-----+------+----------+
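
For reference, the sample above can be recreated with something like this (the column types are assumed; log_date is kept as a string here, which still sorts chronologically because it is in ISO yyyy-MM-dd format):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# sample data matching the table above
df_merged = spark.createDataFrame(
    [
        ("AZ", 820198, "2021-03-07"),
        ("AZ", 820198, "2021-03-06"),
        ("AZ", 818784, "2021-03-05"),
        ("AZ", 801115, "2021-03-03"),
        ("AK", 46819, "2021-03-07"),
        ("AK", 46819, "2021-03-06"),
        ("AK", 46819, "2021-03-05"),
        ("AK", 46819, "2021-01-23"),
    ],
    ["state", "count", "log_date"],
)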

This is what I have so far, and it seems to work under these specific data conditions. I've been at this for three days, and every day I've had to rework it when a new condition/date is added to the data.

I'd like to know if this can be simplified at all based on the original requirement. It just seems like a lot of partitioning when there is probably a smarter way in PySpark that I don't know.

from pyspark.sql import functions as F, Window

# start by getting the earliest date each count appears in the list
w = Window.partitionBy("state", "count").orderBy(F.desc("count"))
df = df_merged.withColumn("min_date_count_appears", F.min("log_date").over(w)).orderBy("state")
df.show()

# filter out records with the same count that appear at later dates
df = df.where(F.col("min_date_count_appears") == F.col("log_date")).orderBy("state")
df.show()

# add rank by log_date to get last two logs
w = Window.partitionBy("state").orderBy(F.desc("log_date"), F.desc("count"))
df = df.withColumn("rank", F.dense_rank().over(w))
df.show()

# get the lowest count in the list in case there has been no increase
window = Window.partitionBy("state")
df = df.withColumn("min_count", F.min("count").over(window)).withColumn("min_date", F.min("log_date").over(window.orderBy(F.desc("min_count"))))
df.show()

# get the top two logs to compare their counts. If there is no increase then fall back to the min_count
df1 = df.where(F.col("rank") == 1)
df2 = df.where(F.col("rank") == 2).select("state", "count", "log_date")
df2 = df2.withColumnRenamed("count", "prev_count").withColumnRenamed("log_date", "prev_date")
df = df1.join(df2, "state", "inner")
df.show()

case_increase = F.col("count") > F.col("prev_count")
df = df.withColumn("last_import", F.when(case_increase, F.col("log_date")).otherwise(F.col("min_date")))
df = df.withColumn("days_since_import", F.datediff(F.current_date(), df.last_import))
df.show()

Upvotes: 0

Views: 126

Answers (1)

mck

Reputation: 42352

You can get the minimum date for the maximum count as below:

from pyspark.sql import functions as F, Window

df2 = df.withColumn(
    'max_count',
    F.max('count').over(Window.partitionBy('state'))
).withColumn(
    'first_date',
    F.min(
        F.when(
            F.col('count') == F.col('max_count'), 
            F.col('log_date')
        )
    ).over(Window.partitionBy('state'))
)
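
This works because F.when without an otherwise returns null for the rows where count is not the maximum, and F.min skips nulls, so the minimum is taken only over the dates on which the count equals max_count.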

df2.show()
+-----+------+----------+---------+----------+
|state| count|  log_date|max_count|first_date|
+-----+------+----------+---------+----------+
|   AZ|820198|2021-03-07|   820198|2021-03-06|
|   AZ|820198|2021-03-06|   820198|2021-03-06|
|   AZ|818784|2021-03-05|   820198|2021-03-06|
|   AZ|801115|2021-03-03|   820198|2021-03-06|
|   AK| 46819|2021-03-07|    46819|2021-01-23|
|   AK| 46819|2021-03-06|    46819|2021-01-23|
|   AK| 46819|2021-03-05|    46819|2021-01-23|
|   AK| 46819|2021-01-23|    46819|2021-01-23|
+-----+------+----------+---------+----------+
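
Note that this also handles the no-increase case: for AK the maximum count appears on every row, so first_date is simply the earliest date in the log, 2021-01-23.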

If you just want the dates and counts, you can do

df3 = df2.select('state', 'max_count', 'first_date').distinct()

df3.show()
+-----+---------+----------+
|state|max_count|first_date|
+-----+---------+----------+
|   AZ|   820198|2021-03-06|
|   AK|    46819|2021-01-23|
+-----+---------+----------+
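
If you prefer to avoid window functions, a rough sketch of an equivalent approach is to aggregate the maximum count per state, join it back, and take the earliest matching date:

max_counts = df.groupBy('state').agg(F.max('count').alias('max_count'))

df3 = (
    df.join(max_counts, 'state')
      .where(F.col('count') == F.col('max_count'))
      .groupBy('state', 'max_count')
      .agg(F.min('log_date').alias('first_date'))
)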

Upvotes: 2
