User12345

Reputation: 5480

Update multiple columns based on two columns in PySpark DataFrames

I have a DataFrame like the one below in PySpark.

+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
|     serial_number  |     rest_id  |     value  |     body  |     legs  |     face  |     idle  |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
| sn11               | rs1          | N          | Y         | N         | N         | acde      |
| sn1                | rs1          | N          | Y         | N         | N         | den       |
| sn1                | null         | Y          | N         | Y         | N         | can       |
| sn2                | rs2          | Y          | Y         | N         | N         | aeg       |
| null               | rs2          | N          | Y         | N         | Y         | ueg       |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+

Now I want to update some of the columns based on the values in other columns.

Whenever any row with a given serial_number or rest_id has the value Y in a column, every row with that serial_number or rest_id should have that column updated to Y; otherwise the rows keep whatever values they already have.

I have done it like below for the value column:

from pyspark.sql.functions import col, when

df.alias('a') \
    .join(df.filter(col('value') == 'Y').alias('b'),
          on=(col('a.serial_number') == col('b.serial_number')) |
             (col('a.rest_id') == col('b.rest_id')),
          how='left') \
    .withColumn('final_value',
                when(col('b.value').isNull(), col('a.value'))
                .otherwise(col('b.value'))) \
    .select('a.serial_number', 'a.rest_id', 'a.body',
            'a.legs', 'a.face', 'a.idle', 'final_value')

This gives me the result I want.

Now I want to repeat the same for the body, legs and face columns as well.

I could do the above for each column individually, that is, three more join statements, but I want to update all four columns in a single statement.

How can I do that?

Expected result

+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
|     serial_number  |     rest_id  |     value  |     body  |     legs  |     face  |     idle  |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+
| sn11               | rs1          | N          | Y         | N         | N         | acde      |
| sn1                | rs1          | Y          | Y         | Y         | N         | den       |
| sn1                | null         | Y          | Y         | Y         | N         | can       |
| sn2                | rs2          | Y          | Y         | N         | Y         | aeg       |
| null               | rs2          | Y          | Y         | N         | Y         | ueg       |
+--------------------+--------------+------------+-----------+-----------+-----------+-----------+

Upvotes: 0

Views: 3719

Answers (1)

Ramesh Maharjan

Reputation: 41987

You should use window functions over both the serial_number and rest_id columns to check whether Y is present within those groups (comments are provided in the code below as explanation).

# column names to loop over for the updates
columns = ["value", "body", "legs", "face"]

from pyspark.sql import window as w
# window spanning every row that shares a serial_number
windowSpec1 = w.Window.partitionBy('serial_number') \
    .rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)
# window spanning every row that shares a rest_id
windowSpec2 = w.Window.partitionBy('rest_id') \
    .rowsBetween(w.Window.unboundedPreceding, w.Window.unboundedFollowing)

from pyspark.sql import functions as f
from pyspark.sql import types as t

# udf that checks whether Y appears in the list of values
# collected over the windows defined above
def containsUdf(x):
    return "Y" in x

containsUdfCall = f.udf(containsUdf, t.BooleanType())

# for each column, collect its values within both windows and set the
# column to Y when either group contains a Y, otherwise keep the value
for column in columns:
    df = df.withColumn(column,
                       f.when(containsUdfCall(f.collect_list(column).over(windowSpec1)) |
                              containsUdfCall(f.collect_list(column).over(windowSpec2)), "Y")
                        .otherwise(df[column]))

df.show(truncate=False)

which should give you

+-------------+-------+-----+----+----+----+----+
|serial_number|rest_id|value|body|legs|face|idle|
+-------------+-------+-----+----+----+----+----+
|sn2          |rs2    |Y    |Y   |N   |Y   |aeg |
|null         |rs2    |Y    |Y   |N   |Y   |ueg |
|sn11         |rs1    |N    |Y   |N   |N   |acde|
|sn1          |rs1    |Y    |Y   |Y   |N   |den |
|sn1          |null   |Y    |Y   |Y   |N   |can |
+-------------+-------+-----+----+----+----+----+
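As a side note, the Python udf is not strictly necessary here: Spark's built-in array_contains can express the same membership check. A minimal sketch of that variant, assuming the same columns, windowSpec1 and windowSpec2 as above:

# same logic with the built-in array_contains instead of a udf
for column in columns:
    df = df.withColumn(column,
                       f.when(f.array_contains(f.collect_list(column).over(windowSpec1), "Y") |
                              f.array_contains(f.collect_list(column).over(windowSpec2), "Y"), "Y")
                        .otherwise(df[column]))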

I would recommend applying the two window functions in separate loops, as using both of them at the same time for every row might give you memory exceptions on big data.
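A minimal sketch of that two-pass idea, assuming the same columns, windows and udf as above (the *_sn_flag helper columns are a naming choice for this sketch, not from the answer): the first loop records the serial_number group check, and the second combines it with the rest_id group check, so each expression evaluates only one window:

# first pass: flag whether each column's serial_number group contains a Y
for column in columns:
    df = df.withColumn(column + "_sn_flag",
                       containsUdfCall(f.collect_list(column).over(windowSpec1)))

# second pass: combine the flag with the rest_id group check, then drop it
for column in columns:
    df = df.withColumn(column,
                       f.when(f.col(column + "_sn_flag") |
                              containsUdfCall(f.collect_list(column).over(windowSpec2)), "Y")
                        .otherwise(df[column])) \
           .drop(column + "_sn_flag")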

Upvotes: 1
