Reputation: 107
How can I explode multiple array columns with variable lengths and potential nulls?
My input data looks like this:
+----+------------+--------------+--------------------+
|col1| col2| col3| col4|
+----+------------+--------------+--------------------+
| 1|[id_1, id_2]| [tim, steve]| [apple, pear]|
| 2|[id_3, id_4]| [jenny]| [avocado]|
| 3| null|[tommy, megan]| [apple, strawberry]|
| 4| null| null|[banana, strawberry]|
+----+------------+--------------+--------------------+
I need to explode these arrays together so the rows stay aligned: a single-element array repeats its value, and a null array becomes null columns. My output should look like this:
+----+----+-----+----------+
|col1|col2|col3 |col4 |
+----+----+-----+----------+
|1 |id_1|tim |apple |
|1 |id_2|steve|pear |
|2 |id_3|jenny|avocado |
|2 |id_4|jenny|avocado |
|3 |null|tommy|apple |
|3 |null|megan|strawberry|
|4 |null|null |banana |
|4 |null|null |strawberry|
+----+----+-----+----------+
I have been able to achieve this using the following code, but I feel like there must be a more straightforward approach:
df = spark.createDataFrame(
    [
        (1, ["id_1", "id_2"], ["tim", "steve"], ["apple", "pear"]),
        (2, ["id_3", "id_4"], ["jenny"], ["avocado"]),
        (3, None, ["tommy", "megan"], ["apple", "strawberry"]),
        (4, None, None, ["banana", "strawberry"])
    ],
    ["col1", "col2", "col3", "col4"]
)
df.createOrReplaceTempView("my_table")

spark.sql("""
with cte as (
    SELECT
        col1,
        col2,
        col3,
        col4,
        greatest(size(col2), size(col3), size(col4)) as max_array_len
    FROM my_table
),
arrays_extended as (
    select
        col1,
        case
            when col2 is null then array_repeat(null, max_array_len)
            else col2
        end as col2,
        case
            when size(col3) = 1 then array_repeat(col3[0], max_array_len)
            when col3 is null then array_repeat(null, max_array_len)
            else col3
        end as col3,
        case
            when size(col4) = 1 then array_repeat(col4[0], max_array_len)
            when col4 is null then array_repeat(null, max_array_len)
            else col4
        end as col4
    from cte
),
arrays_zipped as (
    select *, explode(arrays_zip(col2, col3, col4)) as zipped
    from arrays_extended
)
select
    col1,
    zipped.col2,
    zipped.col3,
    zipped.col4
from arrays_zipped
""").show(truncate=False)
Upvotes: 0
Views: 762
Reputation: 6644
I used your logic and shortened it a little.
import pyspark.sql.functions as func

arrcols = ['col2', 'col3', 'col4']

data_sdf. \
    selectExpr(*['coalesce({0}, array()) as {0}'.format(c) if c in arrcols else c for c in data_sdf.columns]). \
    withColumn('max_size', func.greatest(*[func.size(c) for c in arrcols])). \
    selectExpr('col1',
               *['flatten(array({0}, array_repeat(element_at({0}, -1), max_size-size({0})))) as {0}'.format(c) for c in arrcols]
               ). \
    withColumn('arrzip', func.arrays_zip(*arrcols)). \
    selectExpr('col1', 'inline(arrzip)'). \
    orderBy('col1', 'col2'). \
    show()
# +----+----+-----+----------+
# |col1|col2| col3| col4|
# +----+----+-----+----------+
# | 1|id_1| tim| apple|
# | 1|id_2|steve| pear|
# | 2|id_3|jenny| avocado|
# | 2|id_4|jenny| avocado|
# | 3|null|megan|strawberry|
# | 3|null|tommy| apple|
# | 4|null| null| banana|
# | 4|null| null|strawberry|
# +----+----+-----+----------+
approach steps

- coalesce the null columns to empty arrays
- calculate the max size among the row's arrays
- pad each shorter array by repeating its last element max_size-size({0}) times with array_repeat (similar to your approach)
- arrays_zip them and explode (using the inline() sql function)

the list comprehension in the second selectExpr generates the following
['flatten(array({0}, array_repeat(element_at({0}, -1), max_size-size({0})))) as {0}'.format(c) for c in arrcols]
# ['flatten(array(col2, array_repeat(element_at(col2, -1), max_size-size(col2)))) as col2',
# 'flatten(array(col3, array_repeat(element_at(col3, -1), max_size-size(col3)))) as col3',
# 'flatten(array(col4, array_repeat(element_at(col4, -1), max_size-size(col4)))) as col4']
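for instance, on row 2's col3 the padding expression turns ['jenny'] into ['jenny', 'jenny'] when max_size is 2. a quick standalone sanity check of the expression (assuming an active spark session):

# evaluate the padding expression on a literal one-element array
spark.sql("""
    select flatten(array(col3, array_repeat(element_at(col3, -1), 2 - size(col3)))) as padded
    from (select array('jenny') as col3)
""").show()
# +--------------+
# |        padded|
# +--------------+
# |[jenny, jenny]|
# +--------------+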
if it helps, here are the optimized logical plan and the physical plan that Spark generated
== Optimized Logical Plan ==
Generate inline(arrzip#363), [1], false, [col2#369, col3#370, col4#371]
+- Project [col1#0L, arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4) AS arrzip#363]
+- Filter (size(arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4), true) > 0)
+- LogicalRDD [col1#0L, col2#1, col3#2, col4#3], false
== Physical Plan ==
Generate inline(arrzip#363), [col1#0L], false, [col2#369, col3#370, col4#371]
+- *(1) Project [col1#0L, arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4) AS arrzip#363]
+- *(1) Filter (size(arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4), true) > 0)
+- *(1) Scan ExistingRDD[col1#0L,col2#1,col3#2,col4#3]
Upvotes: 0
Reputation: 781
After you get max_array_len, just use the sequence function to iterate through the arrays, transform each index into a struct, and then explode the resulting array of structs; see the SQL below:
spark.sql("""
with cte as (
    SELECT
        col1,
        col2,
        col3,
        col4,
        greatest(size(col2), size(col3), size(col4)) as max_array_len
    FROM my_table
)
SELECT inline_outer(
    transform(
        sequence(0, max_array_len-1), i -> (
            col1 as col1,
            col2[i] as col2,
            coalesce(col3[i], col3[0]) as col3, /* fill null with the first array item of col3 */
            coalesce(col4[i], element_at(col4, -1)) as col4 /* fill null with the last array item of col4 */
        )
    )
)
FROM cte
""").show()
+----+----+-----+----------+
|col1|col2| col3| col4|
+----+----+-----+----------+
| 1|id_1| tim| apple|
| 1|id_2|steve| pear|
| 2|id_3|jenny| avocado|
| 2|id_4|jenny| avocado|
| 3|null|tommy| apple|
| 3|null|megan|strawberry|
| 4|null| null| banana|
| 4|null| null|strawberry|
+----+----+-----+----------+
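If you prefer to stay in the DataFrame API, the same sequence/transform/explode idea can be written with the built-in higher-order functions. A minimal sketch, assuming Spark >= 3.4 (for pyspark.sql.functions.inline) and the df from the question:

import pyspark.sql.functions as F

# sketch of the same idea in the DataFrame API; assumes Spark >= 3.4 for F.inline
(df
 .withColumn('max_array_len',
             F.greatest(F.size('col2'), F.size('col3'), F.size('col4')))
 .select(F.inline(F.transform(
     F.sequence(F.lit(0), F.col('max_array_len') - 1),
     lambda i: F.struct(
         F.col('col1').alias('col1'),
         F.col('col2')[i].alias('col2'),
         # fill nulls the same way as the SQL above:
         # first item for col3, last item for col4
         F.coalesce(F.col('col3')[i], F.col('col3')[0]).alias('col3'),
         F.coalesce(F.col('col4')[i], F.element_at('col4', -1)).alias('col4'),
     ))))
 .show())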
A similar question here.
Upvotes: 1
Reputation: 3639
You can use a UDF:
from pyspark.sql import functions as F, types as T

cols_of_interest = [c for c in df.columns if c != 'col1']

@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
def get_sequences(*cols):
    """Equivalent of arrays_zip, but handling different lengths of the arrays.
    For an array shorter than the maximum length, its last element is repeated.
    """
    # Get the length of the longest array in the row
    max_len = max(map(len, filter(lambda x: x, cols)))
    return list(zip(*[
        # Create a list for each column with a length equal to max_len.
        # If the original column has fewer elements than needed, repeat the last one.
        # None values will be filled with a list of Nones with length max_len.
        [c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len
        for c in cols
    ]))

df2 = (
    df
    .withColumn('temp', F.explode(get_sequences(*cols_of_interest)))
    .select('col1',
            *[F.col('temp').getItem(i).alias(c) for i, c in enumerate(cols_of_interest)])
)
df2 is the following DataFrame:
+----+----+-----+----------+
|col1|col2| col3| col4|
+----+----+-----+----------+
| 1|id_1| tim| apple|
| 1|id_2|steve| pear|
| 2|id_3|jenny| avocado|
| 2|id_4|jenny| avocado|
| 3|null|tommy| apple|
| 3|null|megan|strawberry|
| 4|null| null| banana|
| 4|null| null|strawberry|
+----+----+-----+----------+
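For reference, here is the UDF's zip-and-pad logic run on plain Python lists (row 2 of the sample data), which shows why the single-element arrays get repeated:

# the UDF's core logic on plain Python lists
cols = (['id_3', 'id_4'], ['jenny'], ['avocado'])
max_len = max(map(len, filter(lambda x: x, cols)))  # -> 2
padded = [
    [c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len
    for c in cols
]
print(list(zip(*padded)))
# [('id_3', 'jenny', 'avocado'), ('id_4', 'jenny', 'avocado')]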
Upvotes: 0
Reputation: 5032
You can use inline_outer in conjunction with selectExpr, additionally using coalesce to replace the null columns with empty arrays; arrays_zip then pads any size mismatches between the arrays with nulls.
from pyspark.sql.types import (StructType, StructField, IntegerType,
                               ArrayType, StringType)

inp_data = [
    (1, ['id_1', 'id_2'], ['tim', 'steve'], ['apple', 'pear']),
    (2, ['id_3', 'id_4'], ['jenny'], ['avocado']),
    (3, None, ['tommy', 'megan'], ['apple', 'strawberry']),
    (4, None, None, ['banana', 'strawberry'])
]

inp_schema = StructType([
    StructField('col1', IntegerType(), True),
    StructField('col2', ArrayType(StringType(), True)),
    StructField('col3', ArrayType(StringType(), True)),
    StructField('col4', ArrayType(StringType(), True))
])

sparkDF = sql.createDataFrame(data=inp_data, schema=inp_schema)

sparkDF.show(truncate=False)
+----+------------+--------------+--------------------+
|col1|col2 |col3 |col4 |
+----+------------+--------------+--------------------+
|1 |[id_1, id_2]|[tim, steve] |[apple, pear] |
|2 |[id_3, id_4]|[jenny] |[avocado] |
|3 |null |[tommy, megan]|[apple, strawberry] |
|4 |null |null |[banana, strawberry]|
+----+------------+--------------+--------------------+
sparkDF.selectExpr(
    "col1",
    """inline_outer(arrays_zip(
        coalesce(col2, array()),
        coalesce(col3, array()),
        coalesce(col4, array())
    ))"""
).show(truncate=False)
+----+----+-----+----------+
|col1|0 |1 |2 |
+----+----+-----+----------+
|1 |id_1|tim |apple |
|1 |id_2|steve|pear |
|2 |id_3|jenny|avocado |
|2 |id_4|null |null |
|3 |null|tommy|apple |
|3 |null|megan|strawberry|
|4 |null|null |banana |
|4 |null|null |strawberry|
+----+----+-----+----------+
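The struct fields come out named 0/1/2 because arrays_zip only derives field names from plain column references, not from expressions like coalesce(...). A minimal sketch of one way to keep the original names: coalesce in a separate select first, so arrays_zip sees simple columns:

sparkDF.selectExpr(
    "col1",
    "coalesce(col2, array()) as col2",
    "coalesce(col3, array()) as col3",
    "coalesce(col4, array()) as col4"
).selectExpr(
    "col1",
    "inline_outer(arrays_zip(col2, col3, col4))"
).show(truncate=False)
# same rows as above, with the columns named col2, col3 and col4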
Upvotes: 2