R.Muthuu

Reputation: 33

PySpark Windows Function with Conditional Reset

I have a dataframe like this

|  user_id  | activity_date |
|  -------- | ------------- |
| 49630701  | 1/1/2019     |
| 49630701  | 1/10/2019    |
| 49630701  | 1/28/2019    |
| 49630701  | 2/5/2019     |
| 49630701  | 3/10/2019    |
| 49630701  | 3/21/2019    |
| 49630701  | 5/25/2019    |
| 49630701  | 5/28/2019    |
| 49630701  | 9/10/2019    |
| 49630701  | 1/1/2020     |
| 49630701  | 1/10/2020    |
| 49630701  | 1/28/2020    |
| 49630701  | 2/10/2020    |
| 49630701  | 3/10/2020    |

What I need to create is the "Group" column. The logic is: for every user, keep the same group number as long as the cumulative date difference stays within 30 days; whenever the cumulative date difference exceeds 30 days, increment the group number and reset the cumulative difference to zero (a plain-Python sketch of this rule follows the table below).

|  user_id  | activity_date | Group |
|  -------- | ------------- | ----- |
| 49630701  | 1/1/2019     |  1    |
| 49630701  | 1/10/2019    |  1    |
| 49630701  | 1/28/2019    |  1    | 
| 49630701  | 2/5/2019     |  2    | <- Cumulative date diff till here is 35, which is greater than 30, so increment the Group by 1 and reset the cumulative diff to 0 
| 49630701  | 3/10/2019    |  3    |
| 49630701  | 3/21/2019    |  3    |
| 49630701  | 5/25/2019    |  4    |
| 49630701  | 5/28/2019    |  4    |
| 49630701  | 9/10/2019    |  5    |
| 49630701  | 1/1/2020     |  6    |
| 49630701  | 1/10/2020    |  6    |
| 49630701  | 1/28/2020    |  6    |
| 49630701  | 2/10/2020    |  7    |
| 49630701  | 3/10/2020    |  7    |
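
Here is the same rule written as plain Python over one user's sorted dates, just to make the expected grouping concrete (example data only, not a Spark solution):

from datetime import date

# example data only: the first five activity dates for user 49630701
dates = [date(2019, 1, 1), date(2019, 1, 10), date(2019, 1, 28),
         date(2019, 2, 5), date(2019, 3, 10)]

group, cumul = 1, 0
groups = []
for i, d in enumerate(dates):
    cumul += (d - dates[i - 1]).days if i > 0 else 0
    if cumul > 30:       # cumulative gap exceeded 30 days
        group += 1       # start a new group
        cumul = 0        # and reset the cumulative difference
    groups.append(group)

print(groups)  # [1, 1, 1, 2, 3] -- matches the first five rows above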

I tried the code below with a loop, but it is not efficient and runs for hours. Is there a better way to achieve this? Any help would be really appreciated.

from pyspark.sql import Window
from pyspark.sql.functions import col, lag, datediff, when, lit, rank

df = spark.read.table('excel_file')
df1 = df.select(col("user_id"), col("activity_date")).distinct()

# day difference between each activity date and the previous one, per user
partitionWindow = Window.partitionBy("user_id").orderBy(col("activity_date").asc())
lagTest = lag(col("activity_date"), 1, "0000-00-00 00:00:00").over(partitionWindow)
df1 = df1.select(col("*"), datediff(col("activity_date"), lagTest).cast("int").alias("diff_val_with_previous"))
df1 = df1.withColumn('diff_val_with_previous', when(col('diff_val_with_previous').isNull(), lit(0)).otherwise(col('diff_val_with_previous')))

distinctUser = [i['user_id'] for i in df1.select(col("user_id")).distinct().collect()]
rankTest = rank().over(partitionWindow)
df2 = df1.select(col("*"), rankTest.alias("rank"))

interimSessionThreshold = 30
totalSessionTimeThreshold = 30
rowList = []

# sequential pass: one filter/collect per user and per row, which is what makes this so slow
for x in distinctUser:
  tempDf = df2.filter(col("user_id") == x).orderBy(col('activity_date'))
  cumulDiff = 0
  group = 1
  startBatch = True
  len_df = tempDf.count()
  dp = 0
  for i in range(1, len_df + 1):
    r = tempDf.filter(col("rank") == i)
    dp = r.select("diff_val_with_previous").first()[0]
    cumulDiff += dp
    if (dp <= interimSessionThreshold) & (cumulDiff <= totalSessionTimeThreshold):
      startBatch = False
      rowList.append([r.select("user_id").first()[0], r.select("activity_date").first()[0], group])
    else:
      group += 1
      cumulDiff = 0
      startBatch = True
      dp = 0
      rowList.append([r.select("user_id").first()[0], r.select("activity_date").first()[0], group])

ddf = spark.createDataFrame(rowList, ['user_id', 'activity_date', 'group'])

Upvotes: 0

Views: 270

Answers (1)

Steven

Reputation: 15258

I can think of two solutions, but neither of them matches exactly what you want:

from pyspark.sql import functions as F, Window

df.withColumn(
    "idx", F.monotonically_increasing_id()  # monotonically increasing row id
).withColumn(
    "date_as_num", F.unix_timestamp("activity_date")  # date as seconds so rangeBetween can express "30 days"
).withColumn(
    # smallest row id seen within the trailing 30-day window
    "group", F.min("idx").over(Window.partitionBy("user_id").orderBy("date_as_num").rangeBetween(-60 * 60 * 24 * 30, 0))
).withColumn(
    # turn those ids into consecutive group numbers
    "group", F.dense_rank().over(Window.partitionBy("user_id").orderBy("group"))
).show()

+--------+-------------+----------+-----------+-----+                           
| user_id|activity_date|       idx|date_as_num|group|
+--------+-------------+----------+-----------+-----+
|49630701|   2019-01-01|         0| 1546300800|    1|
|49630701|   2019-01-10|         1| 1547078400|    1|
|49630701|   2019-01-28|         2| 1548633600|    1|
|49630701|   2019-02-05|         3| 1549324800|    2|
|49630701|   2019-03-10|         4| 1552176000|    3|
|49630701|   2019-03-21|         5| 1553126400|    3|
|49630701|   2019-05-25|         6| 1558742400|    4|
|49630701|   2019-05-28|8589934592| 1559001600|    4|
|49630701|   2019-09-10|8589934593| 1568073600|    5|
|49630701|   2020-01-01|8589934594| 1577836800|    6|
|49630701|   2020-01-10|8589934595| 1578614400|    6|
|49630701|   2020-01-28|8589934596| 1580169600|    6|
|49630701|   2020-02-10|8589934597| 1581292800|    7|
|49630701|   2020-03-10|8589934598| 1583798400|    8|
+--------+-------------+----------+-----------+-----+

or

df.withColumn(
    # day gap from the previous activity date of the same user
    "group",
    F.datediff(
        F.col("activity_date"),
        F.lag("activity_date").over(
            Window.partitionBy("user_id").orderBy("activity_date")
        ),
    ),
).withColumn(
    # running total of the gaps
    "group", F.sum("group").over(Window.partitionBy("user_id").orderBy("activity_date"))
).withColumn(
    # bucket the running total into 30-day blocks (the first row has a null gap, hence the coalesce)
    "group", F.floor(F.coalesce(F.col("group"), F.lit(0)) / 30)
).withColumn(
    # renumber the buckets consecutively
    "group", F.dense_rank().over(Window.partitionBy("user_id").orderBy("group"))
).show()

+--------+-------------+-----+                                                  
| user_id|activity_date|group|
+--------+-------------+-----+
|49630701|   2019-01-01|    1|
|49630701|   2019-01-10|    1|
|49630701|   2019-01-28|    1|
|49630701|   2019-02-05|    2|
|49630701|   2019-03-10|    3|
|49630701|   2019-03-21|    3|
|49630701|   2019-05-25|    4|
|49630701|   2019-05-28|    4|
|49630701|   2019-09-10|    5|
|49630701|   2020-01-01|    6|
|49630701|   2020-01-10|    6|
|49630701|   2020-01-28|    7|
|49630701|   2020-02-10|    7|
|49630701|   2020-03-10|    8|
+--------+-------------+-----+
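
If you really need the exact cumulative-reset behaviour from the question, one possible approach (just a sketch, assuming Spark 3.x with pyarrow installed and the column names/types from the question) is a grouped pandas UDF via applyInPandas:

import pandas as pd

def assign_groups(pdf: pd.DataFrame) -> pd.DataFrame:
    # runs once per user_id; implements "increment and reset once the cumulative gap exceeds 30 days"
    pdf = pdf.sort_values("activity_date")
    group, cumul, prev, groups = 1, 0, None, []
    for d in pdf["activity_date"]:
        cumul += (d - prev).days if prev is not None else 0
        if cumul > 30:      # cumulative gap exceeded: new group, reset counter
            group += 1
            cumul = 0
        groups.append(group)
        prev = d
    return pdf.assign(group=groups)

result = df.groupBy("user_id").applyInPandas(
    assign_groups, schema="user_id long, activity_date date, group int"
)

This keeps the sequential logic, but it runs once per user inside a single Spark task instead of doing one filter/collect per row as in the question's loop.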

Upvotes: 1
