Dumbledore__

Reputation: 107

Explode multiple array columns with variable lengths

How can I explode multiple array columns with variable lengths and potential nulls?

My input data looks like this:

+----+------------+--------------+--------------------+
|col1|        col2|          col3|                col4|
+----+------------+--------------+--------------------+
|   1|[id_1, id_2]|  [tim, steve]|       [apple, pear]|
|   2|[id_3, id_4]|       [jenny]|           [avocado]|
|   3|        null|[tommy, megan]| [apple, strawberry]|
|   4|        null|          null|[banana, strawberry]|
+----+------------+--------------+--------------------+

I need to explode this such that:

  1. Array items with the same index are mapped to the same row
  2. If there is only 1 entry in a column, it applies to every exploded row
  3. If an array is null, null is used for every exploded row

My output should look like this:

+----+----+-----+----------+
|col1|col2|col3 |col4      |
+----+----+-----+----------+
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|jenny|avocado   |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
+----+----+-----+----------+

I have been able to achieve this using the following code, but I feel like there must be a more straightforward approach:

df = spark.createDataFrame(
    [
        (1, ["id_1", "id_2"], ["tim", "steve"], ["apple", "pear"]),
        (2, ["id_3", "id_4"], ["jenny"], ["avocado"]),
        (3, None, ["tommy", "megan"], ["apple", "strawberry"]),
        (4, None, None, ["banana", "strawberry"])
    ],
    ["col1", "col2", "col3", "col4"]
)

df.createOrReplaceTempView("my_table")


spark.sql("""
with cte as (
  SELECT
  col1,
  col2,
  col3,
  col4,
  greatest(size(col2), size(col3), size(col4)) as max_array_len
  FROM my_table
), arrays_extended as (
select 
col1,
case 
  when col2 is null then array_repeat(null, max_array_len) 
  else col2 
end as col2, 
case
  when size(col3) = 1 then array_repeat(col3[0], max_array_len)
  when col3 is null then array_repeat(null, max_array_len)
  else col3
end as col3,
case
  when size(col4) = 1 then array_repeat(col4[0], max_array_len)
  when col4 is null then array_repeat(null, max_array_len)
  else col4
end as col4
from cte), 
arrays_zipped as (
select *, explode(arrays_zip(col2, col3, col4)) as zipped
from arrays_extended
)
select 
  col1,
  zipped.col2,
  zipped.col3,
  zipped.col4
from arrays_zipped
""").show(truncate=False)


Upvotes: 0

Views: 762

Answers (4)

samkart

Reputation: 6644

I used your logic and shortened it a little.

import pyspark.sql.functions as func

arrcols = ['col2', 'col3', 'col4']

data_sdf. \
    selectExpr(*['coalesce({0}, array()) as {0}'.format(c) if c in arrcols else c for c in data_sdf.columns]). \
    withColumn('max_size', func.greatest(*[func.size(c) for c in arrcols])). \
    selectExpr('col1', 
               *['flatten(array({0}, array_repeat(element_at({0}, -1), max_size-size({0})))) as {0}'.format(c) for c in arrcols]
               ). \
    withColumn('arrzip', func.arrays_zip(*arrcols)). \
    selectExpr('col1', 'inline(arrzip)'). \
    orderBy('col1', 'col2'). \
    show()

# +----+----+-----+----------+
# |col1|col2| col3|      col4|
# +----+----+-----+----------+
# |   1|id_1|  tim|     apple|
# |   1|id_2|steve|      pear|
# |   2|id_3|jenny|   avocado|
# |   2|id_4|jenny|   avocado|
# |   3|null|megan|strawberry|
# |   3|null|tommy|     apple|
# |   4|null| null|    banana|
# |   4|null| null|strawberry|
# +----+----+-----+----------+

approach steps

  • fill nulls with empty arrays, and take the maximum size across all the array columns
  • pad the arrays that are smaller in size than the others
    • I took the last element of the array and used array_repeat on it (similar to your approach); see the sanity check below
    • the number of repetitions is calculated by checking the max size against the size of the array being worked on (max_size-size({0}))
  • with the aforementioned steps, you will now have the same number of elements in each of the array columns, which enables you to zip (arrays_zip) and explode them (using the inline() sql function)

the list comprehension in the second selectExpr generates the following

['flatten(array({0}, array_repeat(element_at({0}, -1), max_size-size({0})))) as {0}'.format(c) for c in arrcols]

# ['flatten(array(col2, array_repeat(element_at(col2, -1), max_size-size(col2)))) as col2',
#  'flatten(array(col3, array_repeat(element_at(col3, -1), max_size-size(col3)))) as col3',
#  'flatten(array(col4, array_repeat(element_at(col4, -1), max_size-size(col4)))) as col4']
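
as a quick sanity check of the padding expression, here is a standalone snippet (assuming an active SparkSession named spark) that pads a single-element array up to a max size of 2 by repeating its last element

spark.sql("""
    select flatten(array(col3, array_repeat(element_at(col3, -1), 2 - size(col3)))) as padded
    from (select array('jenny') as col3)
""").show()

# +--------------+
# |        padded|
# +--------------+
# |[jenny, jenny]|
# +--------------+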

if it helps, here are the optimized logical plan and physical plan that Spark generated

== Optimized Logical Plan ==
Generate inline(arrzip#363), [1], false, [col2#369, col3#370, col4#371]
+- Project [col1#0L, arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4) AS arrzip#363]
   +- Filter (size(arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4), true) > 0)
      +- LogicalRDD [col1#0L, col2#1, col3#2, col4#3], false

== Physical Plan ==
Generate inline(arrzip#363), [col1#0L], false, [col2#369, col3#370, col4#371]
+- *(1) Project [col1#0L, arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4) AS arrzip#363]
   +- *(1) Filter (size(arrays_zip(flatten(array(coalesce(col2#1, []), array_repeat(element_at(coalesce(col2#1, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col2#1, []), true))))), flatten(array(coalesce(col3#2, []), array_repeat(element_at(coalesce(col3#2, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col3#2, []), true))))), flatten(array(coalesce(col4#3, []), array_repeat(element_at(coalesce(col4#3, []), -1, false), (greatest(size(coalesce(col2#1, []), true), size(coalesce(col3#2, []), true), size(coalesce(col4#3, []), true)) - size(coalesce(col4#3, []), true))))), col2, col3, col4), true) > 0)
      +- *(1) Scan ExistingRDD[col1#0L,col2#1,col3#2,col4#3]

Upvotes: 0

lihao

Reputation: 781

After you get max_array_len, just use the sequence function to iterate through the arrays, transform each index into a struct, and then explode the resulting array of structs; see the SQL below:

spark.sql("""
  with cte as (
    SELECT      
      col1,         
      col2,
      col3,
      col4,
      greatest(size(col2), size(col3), size(col4)) as max_array_len
    FROM my_table
  )
  SELECT inline_outer(
           transform(
             sequence(0,max_array_len-1), i -> (
               col1 as col1,
               col2[i] as col2,
               coalesce(col3[i], col3[0]) as col3,             /* fill null with the first array item of col3 */
               coalesce(col4[i], element_at(col4,-1)) as col4  /* fill null with the last array item of col4 */
             )
           )
         )
  FROM cte
""").show()
+----+----+-----+----------+
|col1|col2| col3|      col4|
+----+----+-----+----------+
|   1|id_1|  tim|     apple|
|   1|id_2|steve|      pear|
|   2|id_3|jenny|   avocado|
|   2|id_4|jenny|   avocado|
|   3|null|tommy|     apple|
|   3|null|megan|strawberry|
|   4|null| null|    banana|
|   4|null| null|strawberry|
+----+----+-----+----------+
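
For reference, sequence(0, max_array_len-1) just builds the index array that transform iterates over; a minimal standalone check (assuming an active SparkSession named spark):

spark.sql("SELECT sequence(0, 2) AS idx").show()
+---------+
|      idx|
+---------+
|[0, 1, 2]|
+---------+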

A similar question here.

Upvotes: 1

PieCot

Reputation: 3639

You can use a UDF:

from pyspark.sql import functions as F, types as T

cols_of_interest = [c for c in df.columns if c != 'col1']

@F.udf(returnType=T.ArrayType(T.ArrayType(T.StringType())))
def get_sequences(*cols):
    """Equivalent of arrays_zip, but handles arrays of different lengths.
       For an array shorter than the maximum length, the last element is repeated.
    """

    # Get the length of the longest array in the row
    max_len = max(map(len, filter(lambda x: x, cols)))

    return list(zip(*[
        # Create a list for each column with a length equal to max_len.
        # If the original column has fewer elements than needed, repeat the last one.
        # None values will be filled with a list of Nones with length max_len.
        [c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len for c in cols
    ]))

df2 = (
    df
    .withColumn('temp', F.explode(get_sequences(*cols_of_interest)))
    .select('col1', 
            *[F.col('temp').getItem(i).alias(c) for i, c in enumerate(cols_of_interest)])
)

df2 is the following DataFrame:

+----+----+-----+----------+
|col1|col2| col3|      col4|
+----+----+-----+----------+
|   1|id_1|  tim|     apple|
|   1|id_2|steve|      pear|
|   2|id_3|jenny|   avocado|
|   2|id_4|jenny|   avocado|
|   3|null|tommy|     apple|
|   3|null|megan|strawberry|
|   4|null| null|    banana|
|   4|null| null|strawberry|
+----+----+-----+----------+
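
To see the zipping logic in isolation, the same padding can be reproduced in plain Python (illustrative only, shown for the row where col1 = 2):

cols = (['id_3', 'id_4'], ['jenny'], ['avocado'])
max_len = max(map(len, filter(lambda x: x, cols)))  # 2
padded = [[c[min(i, len(c) - 1)] for i in range(max_len)] if c else [None] * max_len
          for c in cols]
print(list(zip(*padded)))
# [('id_3', 'jenny', 'avocado'), ('id_4', 'jenny', 'avocado')]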

Upvotes: 0

Vaebhav

Reputation: 5032

You can use inline_outer in conjunction with selectExpr, additionally using coalesce to replace null arrays with empty ones; arrays_zip then pads any size mismatches between the arrays with nulls

Data Preparation

from pyspark.sql.types import StructType, StructField, IntegerType, ArrayType, StringType

inp_data = [
    (1,['id_1', 'id_2'],['tim', 'steve'],['apple', 'pear']),
    (2,['id_3', 'id_4'],['jenny'],['avocado']),
    (3,None,['tommy','megan'],['apple', 'strawberry']),
    (4,None,None,['banana', 'strawberry'])
]

inp_schema = StructType([
                      StructField('col1',IntegerType(),True)
                     ,StructField('col2',ArrayType(StringType(), True))
                     ,StructField('col3',ArrayType(StringType(), True))
                     ,StructField('col4',ArrayType(StringType(), True))
                    ]
                   )

sparkDF = spark.createDataFrame(data=inp_data,schema=inp_schema)

sparkDF.show(truncate=False)

+----+------------+--------------+--------------------+
|col1|col2        |col3          |col4                |
+----+------------+--------------+--------------------+
|1   |[id_1, id_2]|[tim, steve]  |[apple, pear]       |
|2   |[id_3, id_4]|[jenny]       |[avocado]           |
|3   |null        |[tommy, megan]|[apple, strawberry] |
|4   |null        |null          |[banana, strawberry]|
+----+------------+--------------+--------------------+

Inline Outer

sparkDF.selectExpr("col1"
                   ,"""inline_outer(arrays_zip(
                           coalesce(col2,array()),
                           coalesce(col3,array()),
                           coalesce(col4,array())
                       ))""").show(truncate=False)

+----+----+-----+----------+
|col1|0   |1    |2         |
+----+----+-----+----------+
|1   |id_1|tim  |apple     |
|1   |id_2|steve|pear      |
|2   |id_3|jenny|avocado   |
|2   |id_4|null |null      |
|3   |null|tommy|apple     |
|3   |null|megan|strawberry|
|4   |null|null |banana    |
|4   |null|null |strawberry|
+----+----+-----+----------+
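
Note that arrays_zip pads the shorter arrays with nulls rather than repeating an element, which is why the row for col1 = 2 ends up with nulls instead of a repeated jenny. A minimal standalone illustration (assuming an active SparkSession named spark):

spark.sql(
    "select arrays_zip(array('jenny'), array('id_3', 'id_4')) as zipped"
).show(truncate=False)
# the shorter array is null-padded: [{jenny, id_3}, {null, id_4}]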

Upvotes: 2
