guderkar

Reputation: 143

pyspark select first element over window on some condition

Problem

Hello, is there a way in PySpark/Spark to select the first element over some window based on some condition?

Examples

Let's take an example input dataframe:

+---------+----------+----+----+----------------+
|       id| timestamp|  f1|  f2|        computed|
+---------+----------+----+----+----------------+
|        1|2020-01-02|null|c1f2|            [f2]|
|        1|2020-01-01|c1f1|null|            [f1]|
|        2|2020-01-01|c2f1|null|            [f1]|
+---------+----------+----+----+----------------+
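For reference, here's a minimal sketch that builds this example input (my assumptions: id is a long, timestamp and f1/f2 are strings, computed is an array of strings):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

input_df = spark.createDataFrame(
    [
        (1, "2020-01-02", None, "c1f2", ["f2"]),
        (1, "2020-01-01", "c1f1", None, ["f1"]),
        (2, "2020-01-01", "c2f1", None, ["f1"]),
    ],
    "id: long, timestamp: string, f1: string, f2: string, computed: array<string>",
)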

For each id, I want to select the latest value of each column (f1, f2, ...) that was computed, i.e. that appears in that row's computed array.

So the "code" would look like this

cols = ["f1", "f2"]

w = Window().partitionBy("id").orderBy(f.desc("timestamp")).rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

output_df = (
    input_df.select(
        "id",
        *[f.first(col, condition=f.array_contains(f.col("computed"), col)).over(w).alias(col) for col in cols]
    )
    .groupBy("id")
    .agg(*[f.first(col).alias(col) for col in cols])
)

And output should be

+---------+----+----+
|       id|  f1|  f2|
+---------+----+----+
|        1|c1f1|c1f2|
|        2|c2f1|null|
+---------+----+----+

If the input looks like this

+---------+----------+----+----+----------------+
|       id| timestamp|  f1|  f2|        computed|
+---------+----------+----+----+----------------+
|        1|2020-01-02|null|c1f2|        [f1, f2]|
|        1|2020-01-01|c1f1|null|            [f1]|
|        2|2020-01-01|c2f1|null|            [f1]|
+---------+----------+----+----+----------------+

Then the output should be

+---------+----+----+
|       id|  f1|  f2|
+---------+----+----+
|        1|null|c1f2|
|        2|c2f1|null|
+---------+----+----+

As you can see, it's not enough to just use f.first(col, ignorenulls=True), because in this case we don't want to skip a null when it is itself a computed value.
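To illustrate, here's a sketch of the naive attempt (reusing w and cols from above); for the second input it would skip the computed null for id 1 and wrongly return c1f1 as f1:

naive_df = (
    input_df.select(
        "id",
        # WRONG for the second example: for id 1, f1 was recomputed to null
        # on 2020-01-02, but ignorenulls=True skips that null and falls back
        # to the older "c1f1"
        *[f.first(col, ignorenulls=True).over(w).alias(col) for col in cols],
    )
    .groupBy("id")
    .agg(*[f.first(col).alias(col) for col in cols])
)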

Current solution

Step 1

Save original data types

cols = ["f1", "f2"]
# look up each column by name so the dtypes line up with the order of cols
orig_dtypes = [input_df.schema[col].dataType for col in cols]

Step 2

For each column, create a new column holding its value if the column is computed, and also replace original nulls with a "synthetic" <NULL> string:

output_df = input_df.select(
    "id", "timestamp", "computed",
    *[
        f.when(f.array_contains(f.col("computed"), col) & f.col(col).isNotNull(), f.col(col))
        .when(f.array_contains(f.col("computed"), col) & f.col(col).isNull(), "<NULL>")
        .alias(col)
        for col in cols
    ]
)

Step 3

Select the first non-null value over the window (w defined above), because now we know that <NULL> won't be skipped:

output_df = (
    output_df.select(
        "id",
        *[f.first(col, ignorenulls=True).over(w).alias(col) for col in cols],
    )
    .groupBy("id")
    .agg(*[f.first(col).alias(col) for col in cols])
)

Step 4

Replace our "synthetic" <NULL> with the original nulls (DataFrame.replace with a string only touches string columns, which is fine here since the affected columns were retyped to string in step 2):

output_df = output_df.replace("<NULL>", None)

Step 5

Cast the columns back to their original types, because they may have been retyped to string in step 2:

output_df = output_df.select("id", *[f.col(col).cast(type_) for col, type_ in zip(cols, orig_dtypes)])
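A quick check of my own (not part of the original solution) to confirm the round trip preserved the schema:

# the columns should now have their original data types again
assert [output_df.schema[col].dataType for col in cols] == orig_dtypes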

This solution works, but it does not seem like the right way to do it. Besides, it's pretty heavy and takes too long to compute.

Is there any other, more "sparkish" way to do it?

Upvotes: 3

Views: 1389

Answers (1)

blackbishop

Reputation: 32710

Here's one way, using the trick of struct ordering.

Group by id and collect a list of structs like struct<col_exists_in_computed, timestamp, col_value> for each column in the cols list, then use the array_max function on the resulting array to get the latest value you want. (Structs compare field by field, left to right: entries where the column was computed sort above those where it wasn't, and among those the latest timestamp wins.)

from pyspark.sql import functions as F

output_df = input_df.groupBy("id").agg(
    *[F.array_max(
        F.collect_list(
            # structs sort by: was the column computed, then timestamp, then value
            F.struct(F.array_contains("computed", c), F.col("timestamp"), F.col(c))
        )
    )[c].alias(c) for c in cols]  # [c] extracts the value field of the max struct
)

# applied to your second dataframe example, it gives

output_df.show()
#+---+----+----+
#| id|  f1|  f2|
#+---+----+----+
#|  1|null|c1f2|
#|  2|c2f1|null|
#+---+----+----+
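Applied to your first input example, the same expression picks the latest computed value per column and matches the expected output from the question:

output_df.show()
#+---+----+----+
#| id|  f1|  f2|
#+---+----+----+
#|  1|c1f1|c1f2|
#|  2|c2f1|null|
#+---+----+----+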

Upvotes: 4
