3yakuya

Reputation: 2672

PySpark - add missing values per key?

I have a PySpark DataFrame with a non-unique column key and two other columns, number and value.

For most keys, the number column goes from 1 to 12, but for some of them there are gaps in the numbers (for example, we have numbers [1, 2, 5, 9]). I would like to add the missing rows, so that every key has all the numbers in the range 1-12, each populated with the last seen value.

So, for the table

key    number    value
a      1         6
a      2         10
a      5         20
a      9         25

I would like to get

key    number    value
a      1         6
a      2         10
a      3         10
a      4         10
a      5         20
a      6         20
a      7         20
a      8         20
a      9         25
a      10        25
a      11        25
a      12        25

I thought about creating a table of the keys with an array of 1-12, exploding the array and joining it with my original table, then separately populating the value column with the previous value using a window function bounded by the current row. However, this seems a bit inelegant. Is there a better way to achieve what I want?
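To make the semantics of this plan concrete, here is a plain-Python sketch of its three steps (build every key/number pair, join the known rows onto it, then forward-fill value within each key). It works on in-memory (key, number, value) tuples rather than Spark DataFrames, assumes each (key, number) pair appears at most once, and the function name fill_by_grid is hypothetical.

```python
from itertools import product


def fill_by_grid(rows, max_number=12):
    """rows: iterable of (key, number, value) tuples (hypothetical helper)."""
    # Step 1: index the known rows by (key, number), i.e. the join key
    known = {(key, number): value for key, number, value in rows}
    keys = sorted({key for key, _ in known})
    result = []
    last_key, last_value = None, None
    # Step 2: walk the full grid keys x 1..max_number (the "exploded array")
    for key, number in product(keys, range(1, max_number + 1)):
        if key != last_key:
            # New key (partition): forget the previous key's value
            last_key, last_value = key, None
        # Step 3: forward fill, keeping the last non-null value seen for this key
        value = known.get((key, number))
        if value is not None:
            last_value = value
        result.append((key, number, last_value))
    return result
```

For the example above, fill_by_grid([("a", 1, 6), ("a", 2, 10), ("a", 5, 20), ("a", 9, 25)]) returns the twelve rows of the desired table. In Spark the same three steps map onto a cross join (or explode), a join, and a window function.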

Upvotes: 2

Views: 910

Answers (2)

murtihash

Reputation: 8410

You can do this without a join. I have tested this with several different gap patterns, and it works as long as number 1 is always present in the input (the sequence needs to start there) and the range always ends at 12. I use a couple of windows to compute how many numbers are missing after each row, build a custom sequence with an expression, and then explode it to get the desired result. If some of your inputs do not contain number 1, let me know and I will update my solution.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# w: each key's rows ordered by number; w2: the whole partition (used for the max)
w = Window.partitionBy("key").orderBy("number")
w2 = Window.partitionBy("key").orderBy("number")\
    .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)

gap = F.col("number") - F.col("number2")
(df
    # diff: size of the gap before each row (0 if none; lag is null on the first row)
    .withColumn("number2", F.lag("number").over(w))
    .withColumn("diff", F.when(gap > 1, gap).otherwise(F.lit(0)))
    # diff2: how many numbers are missing after each row
    .withColumn("diff2", F.coalesce(F.lead("diff").over(w), F.lit(0)))
    .withColumn("diff2", F.when(F.col("diff2") != 0, F.col("diff2") - 1).otherwise(F.col("diff2")))
    # the last row of each key is extended up to 12
    .withColumn("max", F.max("number").over(w2))
    .withColumn("diff2", F.when((F.col("number") == F.col("max")) & (F.col("number") < 12),
                                F.lit(12) - F.col("number")).otherwise(F.col("diff2")))
    # sequence() lists the numbers this row covers; explode turns them into rows
    .withColumn("number2", F.explode(F.expr("sequence(number, number + diff2)")))
    .select("key", F.col("number2").alias("number"), "value")
    .show())


+---+------+-----+
|key|number|value|
+---+------+-----+
|  a|     1|    6|
|  a|     2|   10|
|  a|     3|   10|
|  a|     4|   10|
|  a|     5|   20|
|  a|     6|   20|
|  a|     7|   20|
|  a|     8|   20|
|  a|     9|   25|
|  a|    10|   25|
|  a|    11|   25|
|  a|    12|   25|
+---+------+-----+
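The idea behind these window columns is that each row is stretched to cover every number up to, but not including, the next number for the same key, and the last row of each key is stretched up to 12. A plain-Python sketch of that expansion, assuming every number lies in 1-12 and each (key, number) pair is unique (expand_rows is a hypothetical name, not part of the answer):

```python
from itertools import groupby
from operator import itemgetter


def expand_rows(rows, max_number=12):
    """rows: (key, number, value) tuples; hypothetical helper."""
    result = []
    # Sort so each key's rows are contiguous and ordered by number (like the window)
    ordered = sorted(rows, key=itemgetter(0, 1))
    for key, group in groupby(ordered, key=itemgetter(0)):
        group = list(group)
        for i, (_, number, value) in enumerate(group):
            # Next number for this key (lead), or max_number + 1 after the last row
            stop = group[i + 1][1] if i + 1 < len(group) else max_number + 1
            # Same numbers as sequence(number, number + diff2) followed by explode
            for n in range(number, stop):
                result.append((key, n, value))
    return result
```

As in the Spark version, no rows are generated below a key's first number, which is why number 1 has to be present in the input.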

Upvotes: 1

pault

Reputation: 43534

I thought about creating a table of a and an array of 1-12, exploding the array and joining with my original table, then separately populating the value column with previous value using a window function bounded by current row. However, it seems a bit inelegant and I wonder if there is a better way to achieve what I want?

I do not think your proposed approach is inelegant, but you can achieve the same thing using range instead of explode.

First, create a DataFrame with all the numbers in your range. You will also want to cross join it with the distinct values of the key column from your DataFrame.

all_numbers = spark.range(1, 13).withColumnRenamed("id", "number")
all_numbers = all_numbers.crossJoin(df.select("key").distinct()).cache()
all_numbers.show()
#+------+---+
#|number|key|
#+------+---+
#|     1|  a|
#|     2|  a|
#|     3|  a|
#|     4|  a|
#|     5|  a|
#|     6|  a|
#|     7|  a|
#|     8|  a|
#|     9|  a|
#|    10|  a|
#|    11|  a|
#|    12|  a|
#+------+---+

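Now you can outer join this to your original DataFrame and forward fill using the last known good value. If the number of keys is small enough, you may be able to broadcast all_numbers (note that Spark may ignore a broadcast hint on a full outer join, because a broadcast hash join does not support that join type):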

from pyspark.sql.functions import broadcast, last
from pyspark.sql import Window

df.join(broadcast(all_numbers), on=["number", "key"], how="outer")\
    .withColumn(
        "value", 
        last(
            "value", 
            ignorenulls=True
        ).over(
            Window.partitionBy("key").orderBy("number")\
                .rowsBetween(Window.unboundedPreceding, 0)
        )
    )\
    .show()
#+------+---+-----+
#|number|key|value|
#+------+---+-----+
#|     1|  a|    6|
#|     2|  a|   10|
#|     3|  a|   10|
#|     4|  a|   10|
#|     5|  a|   20|
#|     6|  a|   20|
#|     7|  a|   20|
#|     8|  a|   20|
#|     9|  a|   25|
#|    10|  a|   25|
#|    11|  a|   25|
#|    12|  a|   25|
#+------+---+-----+
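The window function is what carries values forward: with the frame rowsBetween(Window.unboundedPreceding, 0), last(..., ignorenulls=True) looks at every row from the start of the key's partition up to and including the current row and returns the last non-null value among them. A plain-Python model of that behaviour for a single, already-ordered partition (last_ignorenulls is a hypothetical name) also shows an edge case: if a key has no row for number 1, its leading numbers stay null because there is no earlier value to carry forward.

```python
def last_ignorenulls(values):
    """Model of last(col, ignorenulls=True) over
    rowsBetween(unboundedPreceding, 0) for one ordered partition."""
    result = []
    for i in range(len(values)):
        # Frame = every row from the partition start up to the current row
        frame = [v for v in values[: i + 1] if v is not None]
        result.append(frame[-1] if frame else None)
    return result


# Values for key "a" after the outer join, ordered by number 1..12
joined = [6, 10, None, None, 20, None, None, None, 25, None, None, None]
print(last_ignorenulls(joined))
```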

Upvotes: 2
