Reputation: 2672
I have a PySpark DataFrame with a non-unique key column key and two more columns, number and value. For most keys, the number column goes from 1 to 12, but for some of them there are gaps (for example, we only have numbers [1, 2, 5, 9]). I would like to add the missing rows, so that for every key we have all the numbers in the range 1-12, populated with the last seen value.
So, given the table
key  number  value
a    1       6
a    2       10
a    5       20
a    9       25
I would like to get
key  number  value
a    1       6
a    2       10
a    3       10
a    4       10
a    5       20
a    6       20
a    7       20
a    8       20
a    9       25
a    10      25
a    11      25
a    12      25
I thought about creating a table of a and an array of 1-12, exploding the array and joining it with my original table, then separately populating the value column with the previous value using a window function bounded by the current row. However, it seems a bit inelegant, and I wonder if there is a better way to achieve what I want?
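For reference, a minimal snippet to reproduce the sample DataFrame (assuming an active SparkSession named spark):
# Hypothetical setup for the example above
df = spark.createDataFrame(
    [("a", 1, 6), ("a", 2, 10), ("a", 5, 20), ("a", 9, 25)],
    ["key", "number", "value"],
)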
Upvotes: 2
Views: 910
Reputation: 8410
You could do this without a join. I have tested this multiple times with different gaps, and it will always work as long as number 1 is always provided as input (as you need the sequence to start from there) and the range always extends to 12. I used a couple of windows to get a column I could use in the sequence, then made a custom sequence using expressions, and then exploded it to get the desired result. If for some reason you have inputs that do not include number 1, let me know and I will update my solution.
from pyspark.sql.window import Window
from pyspark.sql import functions as F

w = Window.partitionBy("key").orderBy("number")
w2 = Window.partitionBy("key").orderBy("number").rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

(df
 # number2: previous number per key; diff: size of the gap to the previous row
 .withColumn("number2", F.lag("number").over(w))
 .withColumn("diff", F.when((F.col("number2").isNotNull()) & ((F.col("number") - F.col("number2")) > 1),
                            F.col("number") - F.col("number2")).otherwise(F.lit(0)))
 # diff2: how many numbers each row must fill forward (the gap before the next row)
 .withColumn("diff2", F.lead("diff").over(w))
 .withColumn("diff2", F.when(F.col("diff2").isNull(), F.lit(0)).otherwise(F.col("diff2")))
 .withColumn("diff2", F.when(F.col("diff2") != 0, F.col("diff2") - 1).otherwise(F.col("diff2")))
 # the last number of each key must fill all the way up to 12
 .withColumn("max", F.max("number").over(w2))
 .withColumn("diff2", F.when((F.col("number") == F.col("max")) & (F.col("number") < F.lit(12)),
                             F.lit(12) - F.col("number")).otherwise(F.col("diff2")))
 # build the run of numbers each row stands for, then explode it
 .withColumn("number2", F.when(F.col("diff2") != 0, F.expr("sequence(number, number + diff2, 1)"))
                         .otherwise(F.expr("sequence(number, number + diff2, 0)")))
 .drop("diff", "diff2", "max")
 .withColumn("number2", F.explode("number2")).drop("number")
 .select("key", F.col("number2").alias("number"), "value")
 .show())
+---+------+-----+
|key|number|value|
+---+------+-----+
| a| 1| 6|
| a| 2| 10|
| a| 3| 10|
| a| 4| 10|
| a| 5| 20|
| a| 6| 20|
| a| 7| 20|
| a| 8| 20|
| a| 9| 25|
| a| 10| 25|
| a| 11| 25|
| a| 12| 25|
+---+------+-----+
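In case the sequence function is new to you: sequence(start, stop, step) builds an array column, and a zero step is allowed when start equals stop, which is how the no-gap rows above produce a single-element array. A quick illustration (assuming an active SparkSession named spark):
# sequence(start, stop, step) builds an array; step 0 is legal when start == stop
spark.sql("SELECT sequence(2, 5, 1) AS s, sequence(5, 5, 0) AS t").show()
# +------------+---+
# |           s|  t|
# +------------+---+
# |[2, 3, 4, 5]|[5]|
# +------------+---+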
Upvotes: 1
Reputation: 43534
I thought about creating a table of a and an array of 1-12, exploding the array and joining it with my original table, then separately populating the value column with the previous value using a window function bounded by the current row. However, it seems a bit inelegant, and I wonder if there is a better way to achieve what I want?
I do not think your proposed approach is inelegant, but you can achieve the same thing using range instead of explode.
First create a DataFrame with all the numbers in your range. You will also want to cross join this with the distinct key column from your DataFrame.
# spark.range is end-exclusive, so range(1, 13) produces the numbers 1 through 12
all_numbers = spark.range(1, 13).withColumnRenamed("id", "number")
all_numbers = all_numbers.crossJoin(df.select("key").distinct()).cache()
all_numbers.show()
#+------+---+
#|number|key|
#+------+---+
#| 1| a|
#| 2| a|
#| 3| a|
#| 4| a|
#| 5| a|
#| 6| a|
#| 7| a|
#| 8| a|
#| 9| a|
#| 10| a|
#| 11| a|
#| 12| a|
#+------+---+
Now you can outer join this to your original DataFrame and forward fill using the last known good value. If the number of keys is small enough, you may be able to broadcast all_numbers for the join.
from pyspark.sql.functions import broadcast, last
from pyspark.sql import Window
df.join(broadcast(all_numbers), on=["number", "key"], how="outer")\
    .withColumn(
        "value",
        last("value", ignorenulls=True).over(
            Window.partitionBy("key").orderBy("number")
                  .rowsBetween(Window.unboundedPreceding, 0)
        )
    )\
    .show()
#+------+---+-----+
#|number|key|value|
#+------+---+-----+
#| 1| a| 6|
#| 2| a| 10|
#| 3| a| 10|
#| 4| a| 10|
#| 5| a| 20|
#| 6| a| 20|
#| 7| a| 20|
#| 8| a| 20|
#| 9| a| 25|
#| 10| a| 25|
#| 11| a| 25|
#| 12| a| 25|
#+------+---+-----+
Upvotes: 2