I want to create a lag whose offset comes from the value of another column.
In the data below, Qdf is the Question dataframe and Adf is the Answer dataframe. I have added an Explanation column for clarity (I don't actually need it in my final data).
from pyspark.sql.window import Window
import pyspark.sql.functions as func
from pyspark.sql.types import *
from pyspark.sql import SQLContext
ID = ['A' for i in range(0,10)]+ ['B' for i in range(0,10)]
Day = range(1,11)+range(1,11)
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]
Delivered = [0, 0, 2, 3, 1, 0, 10, 0, 0, 13, 0, 0, 3, 1, 0, 6, 0, 0, 6, 3]
Explanation = ["-", "-", "-", "-", "-", "-", "10 (4+6)", "-", "-", "13 (2+6+5)", "-", "-", "-", "-", "-", "6 (2+4)", "-", "-", "6 (1+2+3)", "-"]
QSchema = StructType([StructField("ID", StringType()),StructField("Day", IntegerType()),StructField("Delay", IntegerType()),StructField("Despatched", IntegerType())])
Qdata = map(list, zip(*[ID,Day,Delay,Despatched]))
Qdf = spark.createDataFrame(Qdata,schema=QSchema)
Qdf.show()
+---+---+-----+----------+
| ID|Day|Delay|Despatched|
+---+---+-----+----------+
|  A|  1|    2|         2|
|  A|  2|    2|         3|
|  A|  3|    2|         1|
|  A|  4|    3|         4|
|  A|  5|    2|         6|
|  A|  6|    4|         2|
|  A|  7|    3|         6|
|  A|  8|    2|         5|
|  A|  9|    2|         3|
|  A| 10|    2|         6|
|  B|  1|    2|         3|
|  B|  2|    2|         1|
|  B|  3|    3|         2|
|  B|  4|    2|         4|
|  B|  5|    4|         1|
|  B|  6|    3|         2|
|  B|  7|    2|         3|
|  B|  8|    2|         3|
|  B|  9|    2|         6|
|  B| 10|    3|         1|
+---+---+-----+----------+
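The despatched quantity should be recorded as delivered Delay days after the day it was despatched, and quantities arriving on the same day are added together (as the Explanation column shows). Ideally I would apply the lag function to the Despatched column, using the Delay column as the offset. The Answer dataset should look like this: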
Adata = map(list, zip(*[ID,Day,Delay,Despatched,Delivered,Explanation]))
ASchema = StructType([StructField("ID", StringType()),StructField("Day", IntegerType()),StructField("Delay", IntegerType()),StructField("Despatched", IntegerType()),StructField("Delivered", IntegerType()),StructField("Explanation", StringType())])
Adf = spark.createDataFrame(Adata,schema=ASchema)
Adf.show()
+---+---+-----+----------+---------+-----------+
| ID|Day|Delay|Despatched|Delivered|Explanation|
+---+---+-----+----------+---------+-----------+
|  A|  1|    2|         2|        0|          -|
|  A|  2|    2|         3|        0|          -|
|  A|  3|    2|         1|        2|          -|
|  A|  4|    3|         4|        3|          -|
|  A|  5|    2|         6|        1|          -|
|  A|  6|    4|         2|        0|          -|
|  A|  7|    3|         6|       10|   10 (4+6)|
|  A|  8|    2|         5|        0|          -|
|  A|  9|    2|         3|        0|          -|
|  A| 10|    2|         6|       13| 13 (2+6+5)|
|  B|  1|    2|         3|        0|          -|
|  B|  2|    2|         1|        0|          -|
|  B|  3|    3|         2|        3|          -|
|  B|  4|    2|         4|        1|          -|
|  B|  5|    4|         1|        0|          -|
|  B|  6|    3|         2|        6|    6 (2+4)|
|  B|  7|    2|         3|        0|          -|
|  B|  8|    2|         3|        0|          -|
|  B|  9|    2|         6|        6|  6 (1+2+3)|
|  B| 10|    3|         1|        3|          -|
+---+---+-----+----------+---------+-----------+
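The rule behind the Delivered column can be checked outside Spark: a despatch on day d with delay k arrives on day d + k, and arrivals for the same ID on the same day are summed, which is what the Explanation column spells out. The sketch below is only a plain-Python reference model under that reading, not Spark code; the helper name expected_delivered is hypothetical.

```python
from collections import defaultdict

# Same input lists as in the question, written Python 3 style
ID = ['A'] * 10 + ['B'] * 10
Day = list(range(1, 11)) + list(range(1, 11))
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]


def expected_delivered(ids, days, delays, despatched):
    """Hypothetical reference model: move each despatch forward by its own
    delay and sum everything that lands on the same (ID, day)."""
    arrivals = defaultdict(int)
    for i, d, k, q in zip(ids, days, delays, despatched):
        arrivals[(i, d + k)] += q
    # Days without an arrival get 0; arrivals after day 10 are simply dropped
    return [arrivals.get((i, d), 0) for i, d in zip(ids, days)]


Delivered = expected_delivered(ID, Day, Delay, Despatched)
print(Delivered)
```

Because the quantity is keyed on the despatching row's delay, this is a forward shift plus a sum rather than a lag. In Spark the same idea would presumably map to a delivery-day column (Day + Delay), a groupBy on ID and that column summing Despatched, and a left join back onto ID and Day with missing values filled as 0.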
I tried the code below, which works for a constant lag of 2:
Qdf1=Qdf.withColumn('Delivered_lag',func.lag(Qdf['Despatched'],2).over(Window.partitionBy("ID").orderBy("Day")))
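For reference, lag(col, 2) over Window.partitionBy("ID").orderBy("Day") gives each row the value two rows earlier within the same ID, or null when no such row exists. A plain-Python model of that fixed-offset behaviour, assuming None stands in for null (the helper lag_fixed is hypothetical, not a PySpark function):

```python
from itertools import groupby

ID = ['A'] * 10 + ['B'] * 10
Day = list(range(1, 11)) + list(range(1, 11))
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]


def lag_fixed(ids, days, values, offset):
    """Hypothetical model of lag(values, offset) over
    partitionBy(ids).orderBy(days); None plays the role of null."""
    rows = sorted(zip(ids, days, range(len(values))))  # sort by ID, then Day
    result = [None] * len(values)
    for _, group in groupby(rows, key=lambda r: r[0]):  # one group per ID
        positions = [pos for _, _, pos in group]  # original row indices in Day order
        for n, pos in enumerate(positions):
            if n >= offset:
                result[pos] = values[positions[n - offset]]
    return result


print(lag_fixed(ID, Day, Despatched, 2))
```

The offset is a single number shared by every row, which is why it cannot come from a column.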
But when I try to lag one column by the value of another column, I get this error:
Qdf1=Qdf.withColumn('Delivered_lag',func.lag(Qdf['Despatched'],Qdf['Delay']).over(Window.partitionBy("ID").orderBy("Day")))
TypeError: 'Column' object is not callable
How can I get around this? I am using PySpark 2.3.1 and Python 2.7.13.
Answer:
The lag function takes a fixed integer as its count parameter, not a column. What you can do instead is loop over the distinct values of Delay and combine when and otherwise to get what you want:
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
import pyspark.sql.functions as F
import pyspark.sql.types as T
# In the pyspark shell, spark already exists; getOrCreate() returns that session
spark = SparkSession.builder.getOrCreate()
ID = ['A' for i in range(0,10)]+ ['B' for i in range(0,10)]
# I had to modify this line because I'm working with Python 3
Day = list(range(1,11))+list(range(1,11))
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]
Delivered = [0, 0, 2, 3, 1, 0, 10, 0, 0, 13, 0, 0, 3, 1, 0, 6, 0, 0, 6, 3]
Explanation = ["-", "-", "-", "-", "-", "-", "10 (4+6)", "-", "-", "13 (2+6+5)", "-", "-", "-", "-", "-", "6 (2+4)", "-", "-", "6 (1+2+3)", "-"]
QSchema = T.StructType([T.StructField("ID", T.StringType()),T.StructField("Day", T.IntegerType()),T.StructField("Delay", T.IntegerType()),T.StructField("Despatched", T.IntegerType())])
Qdata = list(map(list, zip(*[ID,Day,Delay,Despatched])))
Qdf = spark.createDataFrame(Qdata,schema=QSchema)
# Up to here this is basically your code
# First we add an empty Delivered_lag column to Qdf,
# so that every iteration of the following loop can work the same way
Qdf = Qdf.withColumn('Delivered_lag', F.lit(None).cast(T.IntegerType()))
w = Window.partitionBy("ID").orderBy("Day")
# Now we loop over the distinct values of Delay and run lag with each value as a fixed offset;
# otherwise() keeps the values calculated in earlier iterations
for row in Qdf.select('Delay').distinct().collect():
    Qdf = Qdf.withColumn('Delivered_lag', F.when(Qdf['Delay'] == row.Delay, F.lag(Qdf['Despatched'], row.Delay).over(w)).otherwise(Qdf['Delivered_lag']))
Qdf.show()
Output:
+---+---+-----+----------+-------------+
| ID|Day|Delay|Despatched|Delivered_lag|
+---+---+-----+----------+-------------+
|  B|  1|    2|         3|         null|
|  B|  2|    2|         1|         null|
|  B|  3|    3|         2|         null|
|  B|  4|    2|         4|            1|
|  B|  5|    4|         1|            3|
|  B|  6|    3|         2|            2|
|  B|  7|    2|         3|            1|
|  B|  8|    2|         3|            2|
|  B|  9|    2|         6|            3|
|  B| 10|    3|         1|            3|
|  A|  1|    2|         2|         null|
|  A|  2|    2|         3|         null|
|  A|  3|    2|         1|            2|
|  A|  4|    3|         4|            2|
|  A|  5|    2|         6|            1|
|  A|  6|    4|         2|            3|
|  A|  7|    3|         6|            4|
|  A|  8|    2|         5|            2|
|  A|  9|    2|         3|            6|
|  A| 10|    2|         6|            5|
+---+---+-----+----------+-------------+
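The loop works because each pass computes one fixed-offset lag, when keeps it only on rows whose Delay equals that offset, and otherwise carries forward what earlier passes filled in. The plain-Python emulation below (no Spark involved; the helper names lag_fixed and lag_by_column are hypothetical) reproduces the output above. Note that it lags by the current row's Delay, whereas the Delivered column in the question's Answer dataframe moves each despatch forward by the despatching row's Delay and sums collisions. For ID A on day 4, the loop gives 2 (Despatched on day 1) while the expected Delivered is 3.

```python
from itertools import groupby

ID = ['A'] * 10 + ['B'] * 10
Day = list(range(1, 11)) + list(range(1, 11))
Delay = [2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 2, 2, 3, 2, 4, 3, 2, 2, 2, 3]
Despatched = [2, 3, 1, 4, 6, 2, 6, 5, 3, 6, 3, 1, 2, 4, 1, 2, 3, 3, 6, 1]


def lag_fixed(ids, days, values, offset):
    """Hypothetical model of lag(values, offset) over partitionBy(ID).orderBy(Day)."""
    rows = sorted(zip(ids, days, range(len(values))))
    result = [None] * len(values)
    for _, group in groupby(rows, key=lambda r: r[0]):
        positions = [pos for _, _, pos in group]
        for n, pos in enumerate(positions):
            if n >= offset:
                result[pos] = values[positions[n - offset]]
    return result


def lag_by_column(ids, days, values, offsets):
    """Hypothetical emulation of the loop: one fixed lag per distinct offset,
    kept where Delay == offset (when), previous value kept elsewhere (otherwise)."""
    delivered_lag = [None] * len(values)  # the F.lit(None) starting column
    for offset in sorted(set(offsets)):  # distinct Delay values
        lagged = lag_fixed(ids, days, values, offset)
        delivered_lag = [
            lagged[i] if offsets[i] == offset else delivered_lag[i]
            for i in range(len(values))
        ]
    return delivered_lag


result = lag_by_column(ID, Day, Despatched, Delay)
print(result)
```

The number of passes equals the number of distinct Delay values, so this stays cheap as long as Delay has only a few distinct values.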