JMV12

Reputation: 1035

Pyspark Replace DF Value When Value Is In List

I'm trying to write a pyspark script to scrub information from a pyspark df. The df I have looks like:

  hashed_customer     firstname    lastname    email   order_id    status          timestamp
      eater 1_uuid  1_firstname  1_lastname  1_email    12345    OPTED_IN     2020-05-14 20:45:15
      eater 2_uuid  2_firstname  2_lastname  2_email    23456    OPTED_IN     2020-05-14 20:29:22
      eater 3_uuid  3_firstname  3_lastname  3_email    34567    OPTED_IN     2020-05-14 19:31:55
      eater 4_uuid  4_firstname  4_lastname  4_email    45678    OPTED_IN     2020-05-14 17:49:27

I have another pyspark df with the customers whose information I need to scrub from the customer_temp_tb table; it looks like this:

hashed_customer    eaterstatus
   eater 1_uuid      OPTED_OUT
   eater 3_uuid      OPTED_OUT
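
(For reference, a minimal reconstruction of the two dataframes above, assuming an active SparkSession named spark; the variable names df_in and df_out are only illustrative and match the snippets below.)

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# illustrative reconstruction of the main customer dataframe
df_in = spark.createDataFrame(
    [
        ("eater 1_uuid", "1_firstname", "1_lastname", "1_email", 12345, "OPTED_IN", "2020-05-14 20:45:15"),
        ("eater 2_uuid", "2_firstname", "2_lastname", "2_email", 23456, "OPTED_IN", "2020-05-14 20:29:22"),
        ("eater 3_uuid", "3_firstname", "3_lastname", "3_email", 34567, "OPTED_IN", "2020-05-14 19:31:55"),
        ("eater 4_uuid", "4_firstname", "4_lastname", "4_email", 45678, "OPTED_IN", "2020-05-14 17:49:27"),
    ],
    ["hashed_customer", "firstname", "lastname", "email", "order_id", "status", "timestamp"],
)

# illustrative reconstruction of the opt-out dataframe
df_out = spark.createDataFrame(
    [("eater 1_uuid", "OPTED_OUT"), ("eater 3_uuid", "OPTED_OUT")],
    ["hashed_customer", "eaterstatus"],
)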

I'm trying to find a way to remove the firstname, lastname, and email from the first df if the user is in the second df. So far, I've created a list of the hashed_customers from the second df using:

cust_opt_out_id = [row.hashed_customer for row in df_out.collect()]

Now, I'm trying to find a way to remove firstname, lastname, and email from the first df if the hashed_customer ID is in the second df so that the end result would look like:

hashed_customer     firstname    lastname    email   order_id    status          timestamp
   eater 1_uuid           NaN         NaN      NaN    12345    OPTED_IN     2020-05-14 20:45:15
   eater 2_uuid   2_firstname  2_lastname  2_email    23456    OPTED_IN     2020-05-14 20:29:22
   eater 3_uuid           NaN         NaN      NaN    34567    OPTED_IN     2020-05-14 19:31:55
   eater 4_uuid   4_firstname  4_lastname  4_email    45678    OPTED_IN     2020-05-14 17:49:27

How can I create a function to do this? I know in pandas it would be as simple as:

df_in.loc[df_in['hashed_customer'].isin(cust_opt_out_id), ['firstname', 'lastname', 'email']] = np.nan

But this doesn't work in pyspark.

Upvotes: 2

Views: 167

Answers (1)

anky

Reputation: 75080

If I were to replicate your exact logic, we can do the below (comments inline):

from pyspark.sql import functions as F

# collect the opted-out customer ids from the second dataframe
l = df2.select("hashed_customer").collect()
cols_to_update = ['firstname', 'lastname', 'email']  # list of cols to update
# use when + otherwise in a loop for the cols_to_update
cond = [F.when(F.col('hashed_customer').isin([i[0] for i in l]),
               F.lit(None)).otherwise(F.col(col)).alias(col)
        for col in cols_to_update]
# select the changed columns and the other columns
final = df1.select(*cond, *[a for a in df1.columns if a not in cols_to_update])
# order as the original dataframe
final.select(*df1.columns).show()

+---------------+-----------+----------+-------+--------+--------+-------------------+
|hashed_customer|  firstname|  lastname|  email|order_id|  status|          timestamp|
+---------------+-----------+----------+-------+--------+--------+-------------------+
|   eater 1_uuid|       null|      null|   null|   12345|OPTED_IN|2020-05-14 20:45:15|
|   eater 2_uuid|2_firstname|2_lastname|2_email|   23456|OPTED_IN|2020-05-14 20:29:22|
|   eater 3_uuid|       null|      null|   null|   34567|OPTED_IN|2020-05-14 19:31:55|
|   eater 4_uuid|4_firstname|4_lastname|4_email|   45678|OPTED_IN|2020-05-14 17:49:27|
+---------------+-----------+----------+-------+--------+--------+-------------------+
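
As a side note (not part of the original answer): collecting the opted-out ids to the driver is fine for a small lookup table, but the same scrub can also be expressed as a left join, which keeps everything on the executors. A minimal sketch, assuming df1 and df2 are the two dataframes from the question (as in the answer above):

from pyspark.sql import functions as F

cols_to_update = ['firstname', 'lastname', 'email']

# left-join on hashed_customer; eaterstatus is non-null only for opted-out customers
joined = df1.join(
    df2.select('hashed_customer', 'eaterstatus').dropDuplicates(['hashed_customer']),
    on='hashed_customer',
    how='left',
)

# null out the PII columns wherever the join found a match, keep everything else as-is
final = joined.select(
    *[F.when(F.col('eaterstatus').isNotNull(), F.lit(None)).otherwise(F.col(c)).alias(c)
      if c in cols_to_update else F.col(c)
      for c in df1.columns]
)
final.show()

The dropDuplicates call guards against the join fanning out rows if the opt-out table ever contains the same hashed_customer more than once.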

Upvotes: 3
