Chjul

Reputation: 437

Pyspark: how to duplicate a row n time in dataframe?

I've got a dataframe like this, and I want to duplicate each row n times, where n is the value in column n:

A   B   n  
1   2   1  
2   9   1  
3   8   2    
4   1   1    
5   3   3 

And transform it into this:

A   B   n  
1   2   1  
2   9   1  
3   8   2
3   8   2       
4   1   1    
5   3   3 
5   3   3 
5   3   3 

I think I should use explode, but I don't understand how it works...
Thanks

Upvotes: 16

Views: 23995

Answers (3)

Ahmed

Reputation: 719

The explode function returns a new row for each element in the given array or map.
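As a minimal illustration of that behaviour (assuming an existing SparkSession named spark), exploding an array column turns one input row into one output row per element:

from pyspark.sql.functions import explode

# one row whose "vals" column holds an array of three elements
demo = spark.createDataFrame([(1, [10, 20, 30])], ["id", "vals"])

# explode produces one output row per array element
demo.select("id", explode("vals").alias("val")).show()
# +---+---+
# | id|val|
# +---+---+
# |  1| 10|
# |  1| 20|
# |  1| 30|
# +---+---+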

One way to use this function here is to build, with a udf, a list of size n for each row, and then explode the resulting array column.

from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType
    
df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)] ,["A", "B", "n"]) 

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
+---+---+---+

# udf that turns the integer n into a list containing n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))

+---+---+---------+
|  A|  B|        n|
+---+---+---------+
|  1|  2|      [1]|
|  2|  9|      [1]|
|  3|  8|   [2, 2]|
|  4|  1|      [1]|
|  5|  3|[3, 3, 3]|
+---+---+---------+ 

# now explode the array so each element becomes its own row
df2.withColumn('n', explode(df2.n)).show()

+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+

Upvotes: 11

jxc

Reputation: 14008

With Spark 2.4.0+, this is easier with the built-in functions array_repeat + explode:

from pyspark.sql.functions import expr

df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])

new_df = df.withColumn('n', expr('explode(array_repeat(n,int(n)))'))

>>> new_df.show()
+---+---+---+
|  A|  B|  n|
+---+---+---+
|  1|  2|  1|
|  2|  9|  1|
|  3|  8|  2|
|  3|  8|  2|
|  4|  1|  1|
|  5|  3|  3|
|  5|  3|  3|
|  5|  3|  3|
+---+---+---+
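If you would rather stay in the Python DataFrame API instead of an SQL expression, a roughly equivalent sketch (this assumes a Spark version whose array_repeat accepts a Column for the repeat count, e.g. 3.x):

from pyspark.sql.functions import array_repeat, col, explode

# repeat the value of n, n times, then explode the array back into rows
new_df = df.withColumn('n', explode(array_repeat(col('n'), col('n').cast('int'))))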

Upvotes: 22

pault

Reputation: 43544

I think the udf answer by @Ahmed is the best way to go, but here is an alternative method that may be as good or better for small n:

First, collect the maximum value of n over the whole DataFrame:

import pyspark.sql.functions as f

max_n = df.select(f.max('n').alias('max_n')).first()['max_n']
print(max_n)
#3

Now create, for each row, an array of length max_n containing the numbers in range(max_n). This intermediate step produces a DataFrame like:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)])).show()
#+---+---+---+---------+
#|  A|  B|  n|  n_array|
#+---+---+---+---------+
#|  1|  2|  1|[0, 1, 2]|
#|  2|  9|  1|[0, 1, 2]|
#|  3|  8|  2|[0, 1, 2]|
#|  4|  1|  1|[0, 1, 2]|
#|  5|  3|  3|[0, 1, 2]|
#+---+---+---+---------+

Now we explode the n_array column and filter to keep only the array values that are less than n; this ensures we end up with n copies of each row. Finally, we drop the exploded column to get the end result:

df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
    .select('A', 'B', 'n', f.explode('n_array').alias('col'))\
    .where(f.col('col') < f.col('n'))\
    .drop('col')\
    .show()
#+---+---+---+
#|  A|  B|  n|
#+---+---+---+
#|  1|  2|  1|
#|  2|  9|  1|
#|  3|  8|  2|
#|  3|  8|  2|
#|  4|  1|  1|
#|  5|  3|  3|
#|  5|  3|  3|
#|  5|  3|  3|
#+---+---+---+

However, this creates a max_n-length array for each row, as opposed to just an n-length array in the udf solution. It's not immediately clear to me how this will scale vs. the udf for large max_n, but I suspect the udf will win out.

Upvotes: 3
