Michael Huang

Reputation: 11

PySpark withColumn that uses column data from another row

I have a dataframe like this:

order_type  customer_id  order_id  related_order_id
purchase    123          abc       null
return      123          bcd       null
purchase    234          xyz       null
return      234          zzz       null
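
In case it helps to reproduce, the sample dataframe can be built with something like this (I'm treating every column as a string; the exact types don't matter for the question):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# All columns are treated as strings here; this is only an assumption for illustration.
df = spark.createDataFrame(
    [("purchase", "123", "abc", None),
     ("return",   "123", "bcd", None),
     ("purchase", "234", "xyz", None),
     ("return",   "234", "zzz", None)],
    "order_type string, customer_id string, order_id string, related_order_id string",
)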

I want to fill in the related_order_id column with the order_id of the related purchase, but only for rows where order_type is return. A return row and a purchase row are related by their customer_id.

I've tried withColumn(), but I haven't figured out a way to reference data from other rows. The end result should look something like this:

order_type  customer_id  order_id  related_order_id
purchase    123          abc       null
return      123          bcd       abc
purchase    234          xyz       null
return      234          zzz       xyz

Upvotes: 1

Views: 455

Answers (2)

travelingbones

Reputation: 8408

I think you may be able to do this very efficiently with clever use of joins.

Let's make a "left" and a "right" df so we can join them:

df_l = df[['order_type', 'customer_id', 'order_id']]
df_l = df_l[df_l.order_type == "return"]      # left df holds the return rows

df_r = df[['order_type', 'customer_id', 'order_id']] \
    .withColumnRenamed('order_id', 'related_order_id')
df_r = df_r[df_r.order_type == "purchase"]    # right df holds the purchase rows
df_r = df_r.drop('order_type')                # drop it so the join doesn't duplicate columns

Now take a left join on customer_id (joining on the column name keeps a single customer_id column in the result):

df_j = df_l.join(df_r, on='customer_id', how='left')

The joined df_j will have the columns ['customer_id', 'order_type', 'order_id', 'related_order_id'], and the last column should be what you want. df_j only includes the return rows, so we can add the purchase rows back with .union(). Since union() matches columns by position, select the columns in the original order first; the purchase rows in the original df already carry a null related_order_id:

df_final = df_j.select('order_type', 'customer_id', 'order_id', 'related_order_id') \
    .union(df[df.order_type == "purchase"])

My guess is that this will be quite efficient.
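
To sanity-check the result against your expected output (row order is not guaranteed after a union), something like this should do:

df_final.orderBy('customer_id', 'order_type').show()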

Upvotes: 0

Cena

Reputation: 3419

You can use the lag() function to pull data from the previous row.

Assuming a return is always preceded by a purchase for the same customer, you can order the window by order_type ("purchase" sorts alphabetically before "return", so on a return row lag() picks up the purchase's order_id):

from pyspark.sql.window import Window
from pyspark.sql import functions as F
from pyspark.sql.functions import col

w = Window.partitionBy("customer_id").orderBy("order_type")

df.withColumn(
    "related_order_id",
    F.when(col("order_type") == "return", F.lag(col("order_id")).over(w))
     .otherwise(col("related_order_id")),
).show()

Output:

+----------+-----------+--------+----------------+
|order_type|customer_id|order_id|related_order_id|
+----------+-----------+--------+----------------+
|  purchase|        123|     abc|            null|
|    return|        123|     bcd|             abc|
|  purchase|        234|     xyz|            null|
|    return|        234|     zzz|             xyz|
+----------+-----------+--------+----------------+

Upvotes: 1
