masterofnone

Reputation: 65

Conditions in Spark window function

I have a dataframe like

+---+---+---+---+
|  q|  w|  e|  r|
+---+---+---+---+
|  a|  1| 20|  y|
|  a|  2| 22|  z|
|  b|  3| 10|  y|
|  b|  4| 12|  y|
+---+---+---+---+

Within each group q, I want to mark the row with the minimum e among rows where r = z. If a group has no rows with r = z, I want to mark the row with the minimum e overall, even if r = y. Essentially, something like

+---+---+---+---+---+
|  q|  w|  e|  r|  t|
+---+---+---+---+---+
|  a|  1| 20|  y|  0|
|  a|  2| 22|  z|  1|
|  b|  3| 10|  y|  1|
|  b|  4| 12|  y|  0|
+---+---+---+---+---+

I can do it using a number of joins, but that would be too expensive. So I was looking for a window-based solution.

Upvotes: 3

Views: 2853

Answers (3)

masterofnone

Reputation: 65

Adding the Spark Scala version of @werner's accepted answer:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import spark.implicits._

val w = Window.partitionBy("q")

df.withColumn("min_e_with_r_eq_z", min(when($"r" === "z", $"e")).over(w))
  .withColumn("min_e_overall", min("e").over(w))
  .withColumn("t", coalesce($"min_e_with_r_eq_z", $"min_e_overall") === $"e")
  .orderBy("w")
  .show()

Upvotes: 3

mck

Reputation: 42332

You can assign row numbers based on whether r = z and the value of column e:

from pyspark.sql import functions as F, Window

df2 = df.withColumn(
    't', 
     F.when(
        F.row_number().over(
            Window.partitionBy('q')
                  .orderBy((F.col('r') == 'z').desc(), 'e')
        ) == 1, 
        1
    ).otherwise(0)
)

df2.show()
+---+---+---+---+---+
|  q|  w|  e|  r|  t|
+---+---+---+---+---+
|  a|  2| 22|  z|  1|
|  a|  1| 20|  y|  0|
|  b|  3| 10|  y|  1|
|  b|  4| 12|  y|  0|
+---+---+---+---+---+

Upvotes: 2

werner

Reputation: 14845

You can calculate two minimums per group: one over only the rows with r = z, and one over all rows in the group. The first non-null of the two can then be compared to e:

from pyspark.sql import functions as F
from pyspark.sql import Window

df = ...

w = Window.partitionBy("q")
# When ordering is not defined, an unbounded window frame is used by default.

df.withColumn("min_e_with_r_eq_z", F.expr("min(case when r='z' then e else null end)").over(w)) \
    .withColumn("min_e_overall", F.min("e").over(w)) \
    .withColumn("t", F.coalesce("min_e_with_r_eq_z","min_e_overall") == F.col("e")) \
    .orderBy("w") \
    .show()

Output:

+---+---+---+---+-----------------+-------------+-----+
|  q|  w|  e|  r|min_e_with_r_eq_z|min_e_overall|    t|
+---+---+---+---+-----------------+-------------+-----+
|  a|  1| 20|  y|               22|           20|false|
|  a|  2| 22|  z|               22|           20| true|
|  b|  3| 10|  y|             null|           10| true|
|  b|  4| 12|  y|             null|           10|false|
+---+---+---+---+-----------------+-------------+-----+

Note: I assume that q is the grouping column for the window.

Upvotes: 4
