Reputation: 5480
I have a data frame like below in PySpark:
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
| serial_number | rest_id | value | body | legs | face | idle |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
| sn11 | rs1 | N | Y | N | N | acde |
| sn1 | rs1 | N | Y | N | N | den |
| sn1 | null | Y | N | Y | N | can |
| sn2 | rs2 | Y | Y | N | N | aeg |
| null | rs2 | N | Y | N | Y | ueg |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
Now I want to update some of the columns based on other rows' values: whenever any row with a given serial_number or rest_id has value set to Y, then value in all rows with that serial_number or rest_id should be updated to Y; otherwise the rows keep whatever values they have. I have done it like below.
from pyspark.sql.functions import col, when

(df.alias('a')
   .join(df.filter(col('value') == 'Y').alias('b'),
         on=(col('a.serial_number') == col('b.serial_number')) | (col('a.rest_id') == col('b.rest_id')),
         how='left')
   .withColumn('final_value', when(col('b.value').isNull(), col('a.value')).otherwise(col('b.value')))
   .select('a.serial_number', 'a.rest_id', 'a.body', 'a.legs', 'a.face', 'a.idle', 'final_value'))
I got the result I want. Now I want to repeat the same for the columns body, legs and face as well. I could do it as above for each column individually, i.e. with three more join statements, but I want to update all four columns in a single statement. How can I do that?
Expected result:
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
| serial_number | rest_id | value | body | legs | face | idle |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
| sn11 | rs1 | N | Y | N | N | acde |
| sn1 | rs1 | Y | Y | Y | N | den |
| sn1 | null | Y | Y | Y | N | can |
| sn2 | rs2 | Y | Y | N | Y | aeg |
| null | rs2 | Y | Y | N | Y | ueg |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
Upvotes: 0
Views: 3719
Reputation: 41987
You should be using window functions over both the serial_number and rest_id columns, checking whether a Y is present in the column within each of those groups. (Comments are provided in the code below as explanation.)
#column names to loop over for the updates
columns = ["value", "body", "legs", "face"]

from pyspark.sql import window as w
#window for grouping by serial number
windowSpec1 = w.Window.partitionBy('serial_number').rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)
#window for grouping by rest id
windowSpec2 = w.Window.partitionBy('rest_id').rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)

from pyspark.sql import functions as f
from pyspark.sql import types as t
#udf checking whether Y is in the list of column values collected over the windows defined above
def containsUdf(x):
    return "Y" in x

containsUdfCall = f.udf(containsUdf, t.BooleanType())

#loop over the columns, collecting each column's values within both windows and setting the column to Y when either group contains a Y
for column in columns:
    df = df.withColumn(column, f.when(containsUdfCall(f.collect_list(column).over(windowSpec1)) | containsUdfCall(f.collect_list(column).over(windowSpec2)), "Y").otherwise(df[column]))

df.show(truncate=False)
which should give you
+-------------+-------+-----+----+----+----+----+
|serial_number|rest_id|value|body|legs|face|idle|
+-------------+-------+-----+----+----+----+----+
|sn2 |rs2 |Y |Y |N |Y |aeg |
|null |rs2 |Y |Y |N |Y |ueg |
|sn11 |rs1 |N |Y |N |N |acde|
|sn1 |rs1 |Y |Y |Y |N |den |
|sn1 |null |Y |Y |Y |N |can |
+-------------+-------+-----+----+----+----+----+
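For reference, a minimal sketch of how the sample DataFrame could be constructed to try this out locally; the SparkSession setup here is an assumption and not part of the original question:

from pyspark.sql import SparkSession

#assumed local session, just for reproducing the example
spark = SparkSession.builder.master("local[*]").getOrCreate()

#rows copied from the question's input table
data = [("sn11", "rs1", "N", "Y", "N", "N", "acde"),
        ("sn1",  "rs1", "N", "Y", "N", "N", "den"),
        ("sn1",  None,  "Y", "N", "Y", "N", "can"),
        ("sn2",  "rs2", "Y", "Y", "N", "N", "aeg"),
        (None,   "rs2", "N", "Y", "N", "Y", "ueg")]
df = spark.createDataFrame(data, ["serial_number", "rest_id", "value", "body", "legs", "face", "idle"])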
I would recommend applying the two window functions in separate passes, as evaluating both window functions at the same time for every row might give you memory exceptions on big data.
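A minimal sketch of that two-pass idea, reusing the windows and UDF defined above; the _g1 and _g2 helper column names are made up for illustration:

#pass 1: flag rows whose serial_number group contains a Y (only one window evaluated per row)
for column in columns:
    df = df.withColumn(column + "_g1", containsUdfCall(f.collect_list(column).over(windowSpec1)))

#pass 2: the same check for the rest_id group
for column in columns:
    df = df.withColumn(column + "_g2", containsUdfCall(f.collect_list(column).over(windowSpec2)))

#combine the two flags and drop the helper columns
for column in columns:
    df = df.withColumn(column, f.when(df[column + "_g1"] | df[column + "_g2"], "Y").otherwise(df[column])).drop(column + "_g1", column + "_g2")

Keeping the checks in helper columns rather than overwriting each column between the passes preserves the one-pass semantics; if the column itself were updated after the first pass, a Y could propagate transitively across overlapping groups (for example, sn11 would pick up the Y that sn1 contributes to the rs1 group).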
Upvotes: 1