I have the following Spark DataFrame:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('').getOrCreate()
df = spark.createDataFrame([(1, "a", "2"), (2, "b", "2"),(3, "c", "2"), (4, "d", "2"),
(5, "b", "3"), (6, "b", "3"),(7, "c", "2")], ["nr", "column2", "quant"])
which returns me:
+---+-------+-----+
| nr|column2|quant|
+---+-------+-----+
|  1|      a|    2|
|  2|      b|    2|
|  3|      c|    2|
|  4|      d|    2|
|  5|      b|    3|
|  6|      b|    3|
|  7|      c|    2|
+---+-------+-----+
I would like to retrieve rows so that, within each group of 3 consecutive rows (a window of size 3, ordered by nr), the quant column has unique values. The windows are rows 1 to 3, rows 4 to 6, and row 7, and in each window I keep only the rows whose quant value has not already appeared earlier in that window.
The output that I would like to get is the following:
+---+-------+-----+
| nr|column2|quant|
+---+-------+-----+
|  1|      a|    2|
|  4|      d|    2|
|  5|      b|    3|
|  7|      c|    2|
+---+-------+-----+
I am new to Spark, so I would appreciate any help. Thanks!
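The selection rule described above can be pinned down with a small plain-Python sketch (no Spark involved). It assumes the windows are consecutive, non-overlapping blocks of 3 rows taken in nr order, and that "unique" means keeping the first row with each quant value inside a block; the helper name `select_unique_per_block` is hypothetical.

```python
# Sample rows from the question: (nr, column2, quant)
ROWS = [(1, "a", "2"), (2, "b", "2"), (3, "c", "2"), (4, "d", "2"),
        (5, "b", "3"), (6, "b", "3"), (7, "c", "2")]


def select_unique_per_block(rows, block_size=3):
    """Keep the first row for each quant value inside each block of rows.

    Assumption: blocks are consecutive, non-overlapping groups of
    `block_size` rows taken in ascending nr order.
    """
    ordered = sorted(rows, key=lambda r: r[0])  # order by nr
    selected = []
    for start in range(0, len(ordered), block_size):
        seen = set()  # quant values already kept in the current block
        for row in ordered[start:start + block_size]:
            if row[2] not in seen:
                seen.add(row[2])
                selected.append(row)
    return selected


for row in select_unique_per_block(ROWS):
    print(row)
```

Running it prints the rows with nr 1, 4, 5 and 7, the same rows as in the desired output. In this sketch the blocks are defined by position in nr order, not by the nr value itself.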
This approach should work for you, assuming the groups of 3 records are based on the 'nr' column (nr 1 to 3, 4 to 6, and so on). It uses a udf, which decides whether a record should be selected or not, and lag, which is used to get the data of the previous rows.
from pyspark.sql import SparkSession, Window
from pyspark.sql.functions import col, lag, udf
from pyspark.sql.types import BooleanType

spark = SparkSession.builder.appName('').getOrCreate()

def tag_selected(index, current_quant, prev_quant1, prev_quant2):
    # first record in each group is always selected
    if index % 3 == 1:
        return True
    # second record is selected if the previous quant differs from the current one
    if index % 3 == 2 and current_quant != prev_quant1:
        return True
    # third record is selected if neither of the two previous quants equals the current one
    if index % 3 == 0 and current_quant != prev_quant1 and current_quant != prev_quant2:
        return True
    return False

tag_selected_udf = udf(tag_selected, BooleanType())

df = spark.createDataFrame([(1, "a", "2"), (2, "b", "2"), (3, "c", "2"), (4, "d", "2"),
                            (5, "b", "3"), (6, "b", "3"), (7, "c", "2")], ["nr", "column2", "quant"])

window = Window.orderBy("nr")
# lag brings the quant values of the 1st and 2nd previous rows into the current row
df = df.withColumn("prev_quant1", lag(col("quant"), 1, None).over(window)) \
    .withColumn("prev_quant2", lag(col("quant"), 2, None).over(window)) \
    .withColumn("selected",
                tag_selected_udf(col('nr'), col('quant'), col('prev_quant1'), col('prev_quant2'))) \
    .filter(col('selected')).drop("prev_quant1", "prev_quant2", "selected")
df.show()
which results in:
+---+-------+-----+
| nr|column2|quant|
+---+-------+-----+
|  1|      a|    2|
|  4|      d|    2|
|  5|      b|    3|
|  7|      c|    2|
+---+-------+-----+
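To see why the lag/udf combination works, its logic can be emulated in plain Python on the same sample rows. `lag(quant, k)` over `Window.orderBy("nr")` yields the quant value k rows earlier, or the default (None here) when there is no such row, and `tag_selected` then compares the current value only with the earlier values that fall in the same group. This is a sketch for checking the logic without Spark; the helpers `lag_values` and `select_rows` are hypothetical and not part of PySpark.

```python
# Sample rows from the question: (nr, column2, quant)
ROWS = [(1, "a", "2"), (2, "b", "2"), (3, "c", "2"), (4, "d", "2"),
        (5, "b", "3"), (6, "b", "3"), (7, "c", "2")]


def lag_values(values, offset):
    """Emulate lag(col, offset) over an ordered window: each position gets
    the value `offset` positions earlier, or None if there is none."""
    n = len(values)
    k = min(offset, n)
    return [None] * k + values[:n - k]


def tag_selected(index, current_quant, prev_quant1, prev_quant2):
    # Same rule as the udf: the position in the group is taken from nr % 3
    if index % 3 == 1:
        return True
    if index % 3 == 2 and current_quant != prev_quant1:
        return True
    if index % 3 == 0 and current_quant != prev_quant1 and current_quant != prev_quant2:
        return True
    return False


def select_rows(rows):
    ordered = sorted(rows, key=lambda r: r[0])  # Window.orderBy("nr")
    quants = [r[2] for r in ordered]
    prev1 = lag_values(quants, 1)  # like the prev_quant1 column
    prev2 = lag_values(quants, 2)  # like the prev_quant2 column
    return [row for row, p1, p2 in zip(ordered, prev1, prev2)
            if tag_selected(row[0], row[2], p1, p2)]


print([row[0] for row in select_rows(ROWS)])  # [1, 4, 5, 7]
```

Because the group position comes from nr itself, the rule assumes nr runs 1, 2, 3, ... without gaps: for example, rows with nr 1, 2 and 4 and equal quant values would keep both 1 and 4. In Spark, a gap-free position could instead be computed with `row_number()` over the same window and passed to the udf in place of nr.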