user3685285

Reputation: 6586

Spark: get minimum value in column that satisfies a condition

I have a DataFrame in Spark that looks like this:

id |  flag
----------
 0 |  true
 1 |  true
 2 | false
 3 |  true
 4 |  true
 5 |  true
 6 | false
 7 | false
 8 |  true
 9 | false

I want to get another column containing the current row's id if flag == false, or the id of the next row where flag == false, so the output would look like this:

id |  flag | nextOrCurrentFalse
-------------------------------
 0 |  true |                  2
 1 |  true |                  2
 2 | false |                  2
 3 |  true |                  6
 4 |  true |                  6
 5 |  true |                  6
 6 | false |                  6
 7 | false |                  7
 8 |  true |                  9
 9 | false |                  9

I want to do this in a vectorized way (not iterating by row). So effectively, the logic I want is: for each row, take the minimum id over all rows at or after it where flag is false.
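Conceptually, something like this sketch (just my intended semantics, assuming a window frame over the current and all following rows; whether this is efficient is exactly what I'm unsure about):

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window

// df is the DataFrame shown above.
// Frame spanning the current row and everything after it, ordered by id.
val w = Window.orderBy("id").rowsBetween(Window.currentRow, Window.unboundedFollowing)

// For each row: the smallest id at or after it where flag is false.
val result = df.withColumn("nextOrCurrentFalse",
  min(when(col("flag") === false, col("id"))).over(w))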

Upvotes: 0

Views: 3222

Answers (3)

Ged

Reputation: 18023

Having thought about scaling (though it is not yet clear whether Catalyst handles this well), I propose a solution that builds on one of the other answers: it could benefit from partitioning and has far less work to do, simply by thinking about the data first. The point is that some pre-computation and massaging can beat a brute-force approach. Your point about the JOIN is also less of an issue now, as it is a bounded JOIN with no massive generation of intermediate data.

Your comment on the DataFrame approach is slightly off, in that everything presented here uses DataFrames. I think you mean that you want to loop through a DataFrame with a sub-loop that has an exit condition. I can find no such example, and in fact I am not sure it fits the Spark paradigm. The same results, obtained with less processing:

import org.apache.spark.sql.functions._
import org.apache.spark.sql.expressions.Window
import spark.implicits._

val df = Seq((0, true), (1, true), (2, false), (3, true), (4, true),
             (5, true), (6, false), (7, false), (8, true), (9, false)).toDF("id", "flag")

// Window over the (much smaller) set of false ids only.
@transient val w1 = Window.orderBy("id1")

// Just the ids where flag is false.
val ids = df.where("flag = false")
            .select($"id".as("id1"))

// For each false id, look up the previous false id (-1 when there is none).
val ids2 = ids.select($"*", lag("id1", 1, -1).over(w1).alias("prev_id"))

// Each false id covers the range starting just after the previous false id.
val ids3 = ids2.withColumn("prev_id1", col("prev_id") + 1).drop("prev_id")

// A bounded join: every row falls into exactly one [prev_id1, id1] range.
// Less work at scale, and theoretically easier for Catalyst to bound partitions.
// Some understanding of the data is required, but no grouping and no min.
val withNextFalse = df.join(ids3, df("id") >= ids3("prev_id1") && df("id") <= ids3("id1"))
                      .select($"id", $"flag", $"id1".alias("nextOrCurrentFalse"))
                      .orderBy(asc("id"))

withNextFalse.show(false)

which also returns:

+---+-----+------------------+
|id |flag |nextOrCurrentFalse|
+---+-----+------------------+
|0  |true |2                 |
|1  |true |2                 |
|2  |false|2                 |
|3  |true |6                 |
|4  |true |6                 |
|5  |true |6                 |
|6  |false|6                 |
|7  |false|7                 |
|8  |true |9                 |
|9  |false|9                 |
+---+-----+------------------+
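For the example data, ids3 works out to the (prev_id1, id1) ranges (0, 2), (3, 6), (7, 7) and (8, 9); every id falls into exactly one of them, which is why this join stays bounded rather than exploding.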

Upvotes: 2

Ged

Reputation: 18023

See my other answer, which is better, but I have left this one here, possibly for SQL educational purposes.

This does what you want, but I would be keen to know what others think of it at scale. I am going to check how Catalyst handles this procedurally; I suspect there may be misses at partition boundaries, and I am keen to check that as well.

import org.apache.spark.sql.functions._

val df = Seq((0, true), (1, true), (2, false), (3, true), (4, true),
             (5, true), (6, false), (7, false), (8, true), (9, false)).toDF("id", "flag")
df.createOrReplaceTempView("tf")

// Performance? Need to check at some stage how partitioning works in such a case.
spark.sql("CACHE TABLE tf")

// Pair every row with every false row at or after it.
val res1 = spark.sql("""
                       SELECT tf1.*, tf2.id AS id2, tf2.flag AS flag2
                         FROM tf tf1, tf tf2
                        WHERE tf2.id  >= tf1.id
                          AND tf2.flag = false
                     """)

//res1.show(false)
res1.createOrReplaceTempView("res1")
spark.sql("CACHE TABLE res1")

// Keep only the nearest false id for each row.
val res2 = spark.sql(""" SELECT X.id, X.flag, X.id2 AS nextOrCurrentFalse
                           FROM (SELECT *, RANK() OVER (PARTITION BY id ORDER BY id2 ASC) AS rank_val
                                   FROM res1) X
                          WHERE X.rank_val = 1
                       ORDER BY id
                    """)

res2.show(false)
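For the example data this produces the same table as the other answers, with nextOrCurrentFalse holding the nearest false id at or after each row.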

Upvotes: 0

Jason

Reputation: 209

If false values of flag are fairly sparse, you could do it like this:

val ids = df.where("flag = false"). 
             select($"id".as("id1"))  

val withNextFalse = df.join(ids, df("id") <= ids("id1")).
                      groupBy("id", "flag").
                      agg("id1" -> "min")

In the first step, we make a DataFrame of the ids where flag is false. Then we join that DataFrame to the original data on the desired condition (the original id should be less than or equal to the id of a row where flag is false).

To get the first such case, group by id and use agg to find the minimum value of id1 (which is the id of the nearest row with flag = false).

Running on your example data (and sorting on id) gives the desired output:

+---+-----+--------+
| id| flag|min(id1)|
+---+-----+--------+
|  0| true|       2|
|  1| true|       2|
|  2|false|       2|
|  3| true|       6|
|  4| true|       6|
|  5| true|       6|
|  6|false|       6|
|  7|false|       7|
|  8| true|       9|
|  9|false|       9|
+---+-----+--------+
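To match the column name from the question exactly, you can rename the aggregate and sort the result, a small variation on the code above:

val result = df.join(ids, df("id") <= ids("id1")).
                groupBy("id", "flag").
                agg(min("id1").as("nextOrCurrentFalse")).
                orderBy("id")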

This approach could run into performance trouble if the DataFrame is very large and has many rows where flag is false. In that case, you may be better off with an iterative solution.

Upvotes: 2
