Quynh-Mai Chu

Reputation: 155

Exploding multiple array columns in PySpark for a changing input schema

I have the same situation as stated in this question.

df = spark.createDataFrame(
    [(1, "xx", [10, 20], ["a", "b"], ["p", "q"]),
     (2, "yy", [30, 40], ["c", "d"], ["r", "s"]),
     (3, "zz",     None, ["f", "g"], ["e", "k"])],
    ["c1", "c2", "a1", "a2", "a3"])
df.show()
# +---+---+--------+------+------+
# | c1| c2|      a1|    a2|    a3|
# +---+---+--------+------+------+
# |  1| xx|[10, 20]|[a, b]|[p, q]|
# |  2| yy|[30, 40]|[c, d]|[r, s]|
# |  3| zz|    null|[f, g]|[e, k]|
# +---+---+--------+------+------+

I can't figure out a way to explode it correctly in PySpark. How can I achieve this result?

+---+---+----+---+---+
| c1| c2|  a1| a2| a3|
+---+---+----+---+---+
|  1| xx|  10|  a|  p|
|  1| xx|  20|  b|  q|
|  2| yy|  30|  c|  r|
|  2| yy|  40|  d|  s|
|  3| zz|null|  f|  e|
|  3| zz|null|  g|  k|
+---+---+----+---+---+

Upvotes: 1

Views: 1127

Answers (1)

ZygD

Reputation: 24458

The following should work for a dynamic number of array columns.

Spark 3:

from pyspark.sql import functions as F

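# Pick up every array column, whatever the current schema contains.
arr_cols = [name for name, dtype in df.dtypes if dtype.startswith("array")]
df = df.withColumn(
    "arr_of_struct",
    # Swap a null array for a one-element [null] array before zipping,
    # and alias each input so the struct fields keep the column names.
    F.arrays_zip(*[F.coalesce(c, F.array(F.lit(None))).alias(c) for c in arr_cols])
).select(
    *[c for c in df.columns if c not in arr_cols],
    # inline() turns each struct in the array into its own row.
    F.expr("inline(arr_of_struct)")
)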

df.show()
# +---+---+----+---+---+
# | c1| c2|  a1| a2| a3|
# +---+---+----+---+---+
# |  1| xx|  10|  a|  p|
# |  1| xx|  20|  b|  q|
# |  2| yy|  30|  c|  r|
# |  2| yy|  40|  d|  s|
# |  3| zz|null|  f|  e|
# |  3| zz|null|  g|  k|
# +---+---+----+---+---+
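
The array-column detection relies on the type strings reported by df.dtypes. Here is a minimal plain-Python sketch of that filter, assuming the sample df reports the dtype strings listed below. These are what I would expect for Python ints and strings, not output copied from a real run, so check your own df.dtypes.

```python
# Assumed value of df.dtypes for the sample DataFrame (an assumption, not real output).
dtypes = [
    ("c1", "bigint"),
    ("c2", "string"),
    ("a1", "array<bigint>"),
    ("a2", "array<string>"),
    ("a3", "array<string>"),
]

# Same filter as in the answer: keep columns whose type string starts with "array".
arr_cols = [name for name, dtype in dtypes if dtype.startswith("array")]
# Every other column is carried through unchanged.
other_cols = [name for name, _ in dtypes if name not in arr_cols]

print(arr_cols)
print(other_cols)
```

Because the filter looks only at the type string, it picks up whichever array columns the input happens to have, which is what makes the approach tolerant of a changing schema.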

Spark 2:

from pyspark.sql import functions as F

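# Pick up every array column, whatever the current schema contains.
arr_cols = [name for name, dtype in df.dtypes if dtype.startswith("array")]
df = df.withColumn(
    "my_struct",
    # Swap a null array for [null], zip the arrays, then explode to one struct per row.
    F.explode(F.arrays_zip(*[F.coalesce(c, F.array(F.lit(None))) for c in arr_cols]))
).select(
    *[c for c in df.columns if c not in arr_cols],
    # The struct fields are addressed by position here and renamed to the original columns.
    *[F.col(f"my_struct.{i}").alias(c) for i, c in enumerate(arr_cols)]
)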
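
Both versions do roughly the same thing per row: replace a missing array with [null], zip the arrays element by element, and emit one output row per zipped element. The sketch below models that in plain Python, with no Spark involved. explode_row is a hypothetical helper name, and padding shorter arrays with None reflects how I understand arrays_zip to treat arrays of unequal length.

```python
from itertools import zip_longest


def explode_row(scalars, arrays):
    """Plain-Python model of coalesce + arrays_zip + inline/explode for one row."""
    # Step 1: a missing (None) array becomes a one-element [None] array,
    # mirroring F.coalesce(c, F.array(F.lit(None))).
    filled = [arr if arr is not None else [None] for arr in arrays]
    # Step 2: zip element-wise, padding shorter arrays with None.
    zipped = zip_longest(*filled, fillvalue=None)
    # Step 3: one output row per zipped tuple, repeating the scalar columns.
    return [tuple(scalars) + elems for elems in zipped]


# Sample data from the question: (c1, c2) scalars and (a1, a2, a3) arrays.
rows = [
    ((1, "xx"), ([10, 20], ["a", "b"], ["p", "q"])),
    ((2, "yy"), ([30, 40], ["c", "d"], ["r", "s"])),
    ((3, "zz"), (None, ["f", "g"], ["e", "k"])),
]

result = [out for scalars, arrays in rows for out in explode_row(scalars, arrays)]
for r in result:
    print(r)
```

In this model the third input row still produces two output rows, with None in the a1 position. That is the reason for the coalesce step: as far as I know, zipping with a null array would otherwise give a null result and the row would disappear when exploded.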

Upvotes: 1
