I have a sorted table that looks like the following.
col1 | col2 |
---|---|
1000 | 1000 |
2600 | 2600 |
3600 | 3600 |
3600 | 4050 |
3600 | 4500 |
I want to create a flag that is true when col1 and col2 are both less than or equal to 4000. This is easy with:
from pyspark.sql.functions import when
pyspark_df = pyspark_df.withColumn('flag', when((pyspark_df['col1'] <= 4000) & (pyspark_df['col2'] <= 4000), 1).otherwise(0))
However, I also want the first row that fails this condition (row 4 in this case) to have the flag set to true. How should I do this?
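You could create a lag column that holds the previous row's flag and then combine the two columns with bitwiseOR. This relies on the passing rows forming a contiguous block at the top of the table, as they do in the example; otherwise every failing row that directly follows a passing row gets flagged too.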
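from pyspark.sql import SparkSession
from pyspark.sql.functions import when, lag, col, monotonically_increasing_id
from pyspark.sql.window import Window
spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame(
    [[1000, 1000],
     [2600, 2600],
     [3600, 3600],
     [3600, 4500],
     [3600, 4500]], ['col1', 'col2']
)
# 1 when both columns are at most 4000, otherwise 0
df = df.withColumn('flag', when((df['col1'] <= 4000) & (df['col2'] <= 4000), 1).otherwise(0))
# Row id that follows the input order here (increasing, but not consecutive);
# with real data, order by the column(s) the table is sorted on instead
df = df.withColumn('idx', monotonically_increasing_id())
# No partitionBy: all rows go to a single partition, which is fine for small data
w = Window.orderBy(col('idx'))
# Previous row's flag; the first row has no predecessor, so default it to 0
df = df.withColumn('lag', lag('flag', 1, 0).over(w))
# Flag a row if it passes or if the row directly before it passed
df = df.withColumn('flag', df.flag.bitwiseOR(df.lag))
df.select('col1', 'col2', 'flag').show()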
Output
+----+----+----+
|col1|col2|flag|
+----+----+----+
|1000|1000|   1|
|2600|2600|   1|
|3600|3600|   1|
|3600|4500|   1|
|3600|4500|   0|
+----+----+----+
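To see why OR-ing with the lagged flag picks up exactly the first failing row, and where it depends on the row order, here is the same logic on plain Python lists. This is only a sketch for illustration, not Spark code: the helper names (`pass_flags`, `lag_or_flags`, `first_failure_flags`) and the `mixed_rows` data are hypothetical, and `first_failure_flags` is an alternative rule (passing rows plus only the first failing row) added for comparison.

```python
from typing import List, Optional, Tuple

Row = Tuple[int, int]


def pass_flags(rows: List[Row], limit: int = 4000) -> List[int]:
    """1 if both columns are <= limit, else 0 (same test as the `when` above)."""
    return [1 if c1 <= limit and c2 <= limit else 0 for c1, c2 in rows]


def lag_or_flags(rows: List[Row], limit: int = 4000) -> List[int]:
    """Lag approach: current flag OR previous flag, with 0 before the first row."""
    flags = pass_flags(rows, limit)
    lagged = [0] + flags[:-1]  # shift by one row, like lag('flag', 1, 0)
    return [f | p for f, p in zip(flags, lagged)]


def first_failure_flags(rows: List[Row], limit: int = 4000) -> List[int]:
    """Hypothetical alternative: passing rows plus only the first failing row."""
    flags = pass_flags(rows, limit)
    first_fail: Optional[int] = next((i for i, f in enumerate(flags) if f == 0), None)
    return [1 if f == 1 or i == first_fail else 0 for i, f in enumerate(flags)]


# The question's data: passing rows form a prefix, so both rules agree.
sorted_rows = [(1000, 1000), (2600, 2600), (3600, 3600), (3600, 4050), (3600, 4500)]
print(lag_or_flags(sorted_rows))         # [1, 1, 1, 1, 0]
print(first_failure_flags(sorted_rows))  # [1, 1, 1, 1, 0]

# Hypothetical data where a passing row follows a failing one: the rules differ.
mixed_rows = [(1000, 1000), (5000, 5000), (2000, 2000), (5000, 5000)]
print(lag_or_flags(mixed_rows))          # [1, 1, 1, 1]
print(first_failure_flags(mixed_rows))   # [1, 1, 1, 0]
```

On the question's data the lag approach gives the requested result. On `mixed_rows` it also flags the last row, because that row follows a passing row, even though it is not the first failure.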