Jahfet

Reputation: 279

Spark: Find the value with the highest occurrence per group over a rolling time window

Starting from the following Spark DataFrame:

from io import StringIO

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.getOrCreate()

pd_df = pd.read_csv(StringIO("""device_id,read_date,id,count
device_A,2017-08-05,4041,3
device_A,2017-08-06,4041,3
device_A,2017-08-07,4041,4
device_A,2017-08-08,4041,3
device_A,2017-08-09,4041,3
device_A,2017-08-10,4041,1
device_A,2017-08-10,4045,2
device_A,2017-08-11,4045,3
device_A,2017-08-12,4045,3
device_A,2017-08-13,4045,3"""), parse_dates=['read_date'])

# Convert to a Spark DataFrame and keep read_date as a DateType column
df = spark.createDataFrame(pd_df).withColumn('read_date', col('read_date').cast('date'))
df.show()

Output:

+--------------+----------+----+-----+
|device_id     | read_date|  id|count|
+--------------+----------+----+-----+
|      device_A|2017-08-05|4041|    3|
|      device_A|2017-08-06|4041|    3|
|      device_A|2017-08-07|4041|    4|
|      device_A|2017-08-08|4041|    3|
|      device_A|2017-08-09|4041|    3|
|      device_A|2017-08-10|4041|    1|
|      device_A|2017-08-10|4045|    2|
|      device_A|2017-08-11|4045|    3|
|      device_A|2017-08-12|4045|    3|
|      device_A|2017-08-13|4045|    3|
+--------------+----------+----+-----+

I would like to find the most frequent id for each (device_id, read_date) combination over a 3-day rolling window (the current day and the two days before it). For each set of rows that falls in the window, I need to sum the counts per id and then return the id with the largest total.

Expected Output:

+--------------+----------+----+
|device_id     | read_date|  id|
+--------------+----------+----+
|      device_A|2017-08-05|4041|
|      device_A|2017-08-06|4041|
|      device_A|2017-08-07|4041|
|      device_A|2017-08-08|4041|
|      device_A|2017-08-09|4041|
|      device_A|2017-08-10|4041|
|      device_A|2017-08-11|4045|
|      device_A|2017-08-12|4045|
|      device_A|2017-08-13|4045|
+--------------+----------+----+
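To pin down the intended semantics, here is a plain-Python reference (not Spark) that sums the counts per id over the trailing 3-day window ending at each read_date and returns the top id. The helper name `rolling_top_id` is hypothetical; it is quadratic per device and only meant for checking Spark results on small samples. Ties are resolved by `Counter.most_common` in first-encountered order, which is an assumption since the question does not say how ties should be broken.

```python
from collections import Counter, defaultdict
from datetime import date, timedelta

# Sample rows from the question: (device_id, read_date, id, count)
ROWS = [
    ("device_A", date(2017, 8, 5), 4041, 3),
    ("device_A", date(2017, 8, 6), 4041, 3),
    ("device_A", date(2017, 8, 7), 4041, 4),
    ("device_A", date(2017, 8, 8), 4041, 3),
    ("device_A", date(2017, 8, 9), 4041, 3),
    ("device_A", date(2017, 8, 10), 4041, 1),
    ("device_A", date(2017, 8, 10), 4045, 2),
    ("device_A", date(2017, 8, 11), 4045, 3),
    ("device_A", date(2017, 8, 12), 4045, 3),
    ("device_A", date(2017, 8, 13), 4045, 3),
]


def rolling_top_id(rows, window_days=3):
    """Map each (device_id, read_date) to the id with the largest summed count
    over the trailing window [read_date - (window_days - 1), read_date]."""
    readings = defaultdict(list)
    for device_id, read_date, id_, count in rows:
        readings[device_id].append((read_date, id_, count))

    result = {}
    for device_id, items in readings.items():
        for day in sorted({d for d, _, _ in items}):
            start = day - timedelta(days=window_days - 1)
            totals = Counter()
            for d, id_, count in items:
                if start <= d <= day:
                    totals[id_] += count
            # most_common(1) returns [(id, total)] for the largest total
            result[(device_id, day)] = totals.most_common(1)[0][0]
    return result


for (device_id, day), top in sorted(rolling_top_id(ROWS).items()):
    print(device_id, day, top)
```

Its result matches the expected output above, which is a handy baseline when comparing Spark approaches.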

I am starting to think this is only possible with a custom aggregation function. Since Spark 2.3 is not out yet, I will have to write it in Scala or use collect_list. Am I missing something?

Upvotes: 3

Views: 9045

Answers (2)

Alper t. Turker

Reputation: 35229

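First, assign each row to sliding 3-day windows (advancing by 1 day) and sum the counts per id within each window: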

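from pyspark.sql.functions import window, sum as sum_, date_add

# Replace read_date with the start date of each 3-day window (sliding by 1 day)
# that contains it; each input row is emitted once per matching window
df_w = df.withColumn(
    "read_date", window("read_date", "3 days", "1 day")["start"].cast("date")
)
# Sum the counts per id within each window
df_w = df_w.groupBy('device_id', 'read_date', 'id').agg(sum_('count').alias('count'))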
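The `window("read_date", "3 days", "1 day")` call assigns every row to each 3-day window that contains it, so a reading on day d is counted in the windows starting on d-2, d-1 and d. Assuming the Spark session time zone is UTC (day-sized windows are aligned to 1970-01-01 00:00:00 UTC), that grouping step can be sketched in plain Python as follows; `window_starts` and `windowed_sums` are hypothetical names and a 1-day slide is hard-coded.

```python
from collections import Counter, defaultdict
from datetime import date, timedelta

# A subset of the question's rows: (device_id, read_date, id, count)
ROWS = [
    ("device_A", date(2017, 8, 8), 4041, 3),
    ("device_A", date(2017, 8, 9), 4041, 3),
    ("device_A", date(2017, 8, 10), 4041, 1),
    ("device_A", date(2017, 8, 10), 4045, 2),
    ("device_A", date(2017, 8, 11), 4045, 3),
]


def window_starts(day, size_days=3):
    """Start dates of every window [start, start + size_days) containing day,
    assuming a 1-day slide aligned to midnight."""
    return [day - timedelta(days=k) for k in range(size_days)]


def windowed_sums(rows, size_days=3):
    """Sum counts per (device_id, window_start) and id, like the groupBy/agg step."""
    sums = defaultdict(Counter)
    for device_id, read_date, id_, count in rows:
        for start in window_starts(read_date, size_days):
            sums[(device_id, start)][id_] += count
    return sums


sums = windowed_sums(ROWS)
print(dict(sums[("device_A", date(2017, 8, 8))]))  # {4041: 7, 4045: 2}
```

Because each row is replicated once per window, the intermediate data grows by roughly the window size divided by the slide (3x here).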

Then keep the top id per window with one of the solutions from "Find maximum row per group in Spark DataFrame", for example `row_number`:

from pyspark.sql.window import Window
from pyspark.sql.functions import col, row_number

rolling_window = 3

top_df = (
    df_w
    .withColumn(
        "rn", 
        row_number().over(
            Window.partitionBy("device_id", "read_date")
            .orderBy(col("count").desc())
        )
    )
    .where(col("rn") == 1)
    .orderBy("read_date")
    .drop("rn")
)

# Results are keyed by the start of each window; shift read_date to the window's last day

final_df = top_df.withColumn('read_date', date_add('read_date', rolling_window - 1))

final_df.show()

# +---------+----------+----+-----+
# |device_id| read_date|  id|count|
# +---------+----------+----+-----+
# | device_A|2017-08-05|4041|    3|
# | device_A|2017-08-06|4041|    6|
# | device_A|2017-08-07|4041|   10|
# | device_A|2017-08-08|4041|   10|
# | device_A|2017-08-09|4041|   10|
# | device_A|2017-08-10|4041|    7|
# | device_A|2017-08-11|4045|    5|
# | device_A|2017-08-12|4045|    8|
# | device_A|2017-08-13|4045|    9|
# | device_A|2017-08-14|4045|    6|
# | device_A|2017-08-15|4045|    3|
# +---------+----------+----+-----+
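Two details in this output may need handling. `row_number()` ordered only by `count` picks an arbitrary row when two ids tie, so adding a secondary key such as `col("id")` to `orderBy` makes the choice repeatable. Shifting by `rolling_window - 1` also yields rows for 2017-08-14 and 2017-08-15, windows that extend past the last reading; a `left_semi` join against `df.select("device_id", "read_date").distinct()` is one way to keep only observed dates. A plain-Python sketch of both steps follows; `top_per_window` is a hypothetical name and the smallest-id tie-break is an arbitrary choice.

```python
from datetime import date, timedelta

# Summed counts per (device_id, window_start) -> {id: total}, computed from the
# sample data (only the top entries appear in the output above)
WINDOW_SUMS = {
    ("device_A", date(2017, 8, 8)): {4041: 7, 4045: 2},
    ("device_A", date(2017, 8, 9)): {4041: 4, 4045: 5},
    ("device_A", date(2017, 8, 12)): {4045: 6},
    ("device_A", date(2017, 8, 13)): {4045: 3},
}

# (device_id, read_date) pairs present in the input data
OBSERVED = {("device_A", date(2017, 8, d)) for d in range(5, 14)}


def top_per_window(window_sums, rolling_window=3, observed=None):
    """Pick the top id per window, label it with the window's last day and,
    if observed is given, keep only (device_id, date) pairs found in the input."""
    result = {}
    for (device_id, start), totals in window_sums.items():
        # Highest total wins; on a tie the smallest (numeric) id wins
        top_id = max(totals, key=lambda i: (totals[i], -i))
        end = start + timedelta(days=rolling_window - 1)
        if observed is None or (device_id, end) in observed:
            result[(device_id, end)] = top_id
    return result


print(top_per_window(WINDOW_SUMS, observed=OBSERVED))
```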

Upvotes: 3

Jahfet

Reputation: 279

I managed to find a very inefficient solution. Hopefully someone can spot improvements that avoid the Python UDF and the calls to collect_list.

from collections import Counter

from pyspark.sql import Window
from pyspark.sql.functions import col, collect_list, first, udf
from pyspark.sql.types import IntegerType

def top_id(ids, counts):
    """Sum the counts per id and return the id with the largest total."""
    c = Counter()
    for cnid, count in zip(ids, counts):
        c[cnid] += count

    return c.most_common(1)[0][0]


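rolling_window = 3

# Convert a number of days to seconds, the unit of the long-cast timestamp below
def days(i):
    return i * 86400

# Define a rolling calculation window based on time: each row sees the rows of
# the same device from (rolling_window - 1) days earlier up to its own read_date
window = (
    Window
        .partitionBy("device_id")
        .orderBy(col("read_date").cast("timestamp").cast("long"))
        .rangeBetween(-days(rolling_window - 1), 0)
)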

# On each row, collect the ids and counts of all rows inside its window
df_collected = df.select(
    'device_id', 'read_date',
    collect_list(col('id')).over(window).alias('ids'),
    collect_list(col('count')).over(window).alias('counts')
)

# Rows sharing a read_date (e.g. the two 2017-08-10 rows) get identical lists; keep one per date
df_grouped = df_collected.groupBy('device_id', 'read_date').agg(
    first('ids').alias('ids'),
    first('counts').alias('counts'),
)

# Register and apply udf to return the most frequently seen id
top_id_udf = udf(top_id, IntegerType())
df_mapped = df_grouped.withColumn('top_id', top_id_udf(col('ids'), col('counts')))

df_mapped.show(truncate=False)

returns:

+---------+----------+------------------------+------------+------+
|device_id|read_date |ids                     |counts      |top_id|
+---------+----------+------------------------+------------+------+
|device_A |2017-08-05|[4041]                  |[3]         |4041  |
|device_A |2017-08-06|[4041, 4041]            |[3, 3]      |4041  |
|device_A |2017-08-07|[4041, 4041, 4041]      |[3, 3, 4]   |4041  |
|device_A |2017-08-08|[4041, 4041, 4041]      |[3, 4, 3]   |4041  |
|device_A |2017-08-09|[4041, 4041, 4041]      |[4, 3, 3]   |4041  |
|device_A |2017-08-10|[4041, 4041, 4041, 4045]|[3, 3, 1, 2]|4041  |
|device_A |2017-08-11|[4041, 4041, 4045, 4045]|[3, 1, 2, 3]|4045  |
|device_A |2017-08-12|[4041, 4045, 4045, 4045]|[1, 2, 3, 3]|4045  |
|device_A |2017-08-13|[4045, 4045, 4045]      |[3, 3, 3]   |4045  |
+---------+----------+------------------------+------------+------+
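The range frame works on epoch seconds: casting `read_date` to a timestamp and then to `long` gives seconds since the epoch, so `rangeBetween(-days(2), 0)` covers the current day and the two before it. Rows sharing the same `read_date` (the two 2017-08-10 rows) are peers in a range frame and should see the same frame, which is why keeping `first()` per date works. The plain-Python sketch below mimics that frame, assuming midnight UTC timestamps (a time zone with daylight-saving shifts could produce 23- or 25-hour days and move the boundary); `epoch_seconds` and `frame_lists` are hypothetical names, and `top_id` is the UDF body from this answer, redefined with its `Counter` import.

```python
from collections import Counter
from datetime import date, datetime, timezone

DAY = 86400  # seconds, same as days(1) in the answer

# The question's rows for 2017-08-09 .. 2017-08-11: (device_id, read_date, id, count)
ROWS = [
    ("device_A", date(2017, 8, 9), 4041, 3),
    ("device_A", date(2017, 8, 10), 4041, 1),
    ("device_A", date(2017, 8, 10), 4045, 2),
    ("device_A", date(2017, 8, 11), 4045, 3),
]


def epoch_seconds(d):
    """Midnight UTC of d as seconds since the epoch (assumes a UTC session)."""
    return int(datetime(d.year, d.month, d.day, tzinfo=timezone.utc).timestamp())


def frame_lists(rows, rolling_window=3):
    """For each row, collect ids/counts of same-device rows whose key lies in
    [key - (rolling_window - 1) * DAY, key], like rangeBetween(-days(2), 0)."""
    out = []
    for device_id, read_date, _, _ in rows:
        key = epoch_seconds(read_date)
        frame = [r for r in rows
                 if r[0] == device_id
                 and key - (rolling_window - 1) * DAY <= epoch_seconds(r[1]) <= key]
        out.append((device_id, read_date,
                    [r[2] for r in frame], [r[3] for r in frame]))
    return out


def top_id(ids, counts):
    c = Counter()
    for cnid, count in zip(ids, counts):
        c[cnid] += count
    return c.most_common(1)[0][0]


for device_id, read_date, ids, counts in frame_lists(ROWS):
    print(read_date, ids, counts, top_id(ids, counts))
```

The 2017-08-11 row reproduces the `[4041, 4041, 4045, 4045]` / `[3, 1, 2, 3]` lists in the table above.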

Upvotes: 1
