Reputation: 437
I've got a dataframe like this, and I want to duplicate each row n times when the value in column n is bigger than one:
A B n
1 2 1
2 9 1
3 8 2
4 1 1
5 3 3
And transform it like this:
A B n
1 2 1
2 9 1
3 8 2
3 8 2
4 1 1
5 3 3
5 3 3
5 3 3
I think I should use explode, but I don't understand how it works...
Thanks
Upvotes: 16
Views: 23995
Reputation: 719
The explode function returns a new row for each element in the given array or map. One way to use it here is to build, with a udf, an array of length n for each row, and then explode the resulting array.
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType
df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])
df.show()
+---+---+---+
| A| B| n|
+---+---+---+
| 1| 2| 1|
| 2| 9| 1|
| 3| 8| 2|
| 4| 1| 1|
| 5| 3| 3|
+---+---+---+
# use a udf to replace each n with an array containing n copies of n
n_to_array = udf(lambda n: [n] * n, ArrayType(IntegerType()))
df2 = df.withColumn('n', n_to_array(df.n))
df2.show()
+---+---+---------+
| A| B| n|
+---+---+---------+
| 1| 2| [1]|
| 2| 9| [1]|
| 3| 8| [2, 2]|
| 4| 1| [1]|
| 5| 3|[3, 3, 3]|
+---+---+---------+
# now explode the array: each copy becomes its own row, restoring the original n
df2.withColumn('n', explode(df2.n)).show()
+---+---+---+
| A| B| n|
+---+---+---+
| 1| 2| 1|
| 2| 9| 1|
| 3| 8| 2|
| 3| 8| 2|
| 4| 1| 1|
| 5| 3| 3|
| 5| 3| 3|
| 5| 3| 3|
+---+---+---+
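A small variant of the same idea (a sketch; the dummy column name is made up): instead of overwriting n with copies of itself, explode a throwaway array of length n and drop it afterwards, so the original columns are never touched:
from pyspark.sql.functions import udf, explode
from pyspark.sql.types import ArrayType, IntegerType

# hypothetical variant: explode a throwaway array of length n, then drop it
dummy_array = udf(lambda n: [0] * n, ArrayType(IntegerType()))
df.withColumn('dummy', explode(dummy_array(df.n))).drop('dummy').show()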
Upvotes: 11
Reputation: 14008
With Spark 2.4.0+, this is easier with the built-in functions array_repeat + explode:
from pyspark.sql.functions import expr
df = spark.createDataFrame([(1,2,1), (2,9,1), (3,8,2), (4,1,1), (5,3,3)], ["A", "B", "n"])
new_df = df.withColumn('n', expr('explode(array_repeat(n,int(n)))'))
new_df.show()
+---+---+---+
| A| B| n|
+---+---+---+
| 1| 2| 1|
| 2| 9| 1|
| 3| 8| 2|
| 3| 8| 2|
| 4| 1| 1|
| 5| 3| 3|
| 5| 3| 3|
| 5| 3| 3|
+---+---+---+
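The same thing can be written with the column API instead of a SQL expression string. A sketch, assuming a Spark version where array_repeat accepts a Column for the count (recent 3.x releases do); n is cast to int because array_repeat expects an integer count:
from pyspark.sql import functions as F

# assumes array_repeat accepts a Column count (Spark 3.x)
new_df = df.withColumn('n', F.explode(F.array_repeat(F.col('n'), F.col('n').cast('int'))))
new_df.show()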
Upvotes: 22
Reputation: 43544
I think the udf answer by @Ahmed is the best way to go, but here is an alternative method that may be as good or better for small n:
First, collect the maximum value of n over the whole DataFrame:
import pyspark.sql.functions as f

max_n = df.select(f.max('n').alias('max_n')).first()['max_n']
print(max_n)
#3
Now create an array of length max_n for each row, containing the numbers in range(max_n). This intermediate step produces a DataFrame like:
df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)])).show()
#+---+---+---+---------+
#| A| B| n| n_array|
#+---+---+---+---------+
#| 1| 2| 1|[0, 1, 2]|
#| 2| 9| 1|[0, 1, 2]|
#| 3| 8| 2|[0, 1, 2]|
#| 4| 1| 1|[0, 1, 2]|
#| 5| 3| 3|[0, 1, 2]|
#+---+---+---+---------+
Now explode the n_array column and filter to keep only the array values that are less than n, which leaves exactly n copies of each row. Finally, drop the exploded column to get the end result:
df.withColumn('n_array', f.array([f.lit(i) for i in range(max_n)]))\
.select('A', 'B', 'n', f.explode('n_array').alias('col'))\
.where(f.col('col') < f.col('n'))\
.drop('col')\
.show()
#+---+---+---+
#| A| B| n|
#+---+---+---+
#| 1| 2| 1|
#| 2| 9| 1|
#| 3| 8| 2|
#| 3| 8| 2|
#| 4| 1| 1|
#| 5| 3| 3|
#| 5| 3| 3|
#| 5| 3| 3|
#+---+---+---+
However, this creates a max_n-length array for each row, as opposed to just an n-length array in the udf solution. It's not immediately clear to me how this will scale vs. the udf for large max_n, but I suspect the udf will win out.
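If the scaling question matters for your data, one rough way to check (a sketch with made-up sizes; results depend entirely on your cluster) is to time both approaches on synthetic data:
import time
from pyspark.sql import functions as f
from pyspark.sql.types import ArrayType, IntegerType

# hypothetical synthetic data: the row count and range of n are made up
big = spark.range(500000).select(
    f.col('id').alias('A'),
    ((f.col('id') % 5) + 1).cast('int').alias('n'),
).cache()
big.count()  # materialize the cache so both runs read the same input

# udf approach
n_to_array = f.udf(lambda n: [n] * n, ArrayType(IntegerType()))
start = time.time()
big.withColumn('n', f.explode(n_to_array('n'))).count()
print('udf:', time.time() - start)

# max_n array approach
max_n = big.select(f.max('n')).first()[0]
start = time.time()
big.withColumn('arr', f.array([f.lit(i) for i in range(max_n)]))\
    .select('A', 'n', f.explode('arr').alias('col'))\
    .where(f.col('col') < f.col('n'))\
    .drop('col')\
    .count()
print('max_n array:', time.time() - start)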
Upvotes: 3