ar_mm18

Reputation: 465

How to explode array column to produce a boolean column in PySpark

I have a data frame like this:

+------------+-----------------+------------------------------------+
| Name       |   Age           | Answers                            |
+------------+-----------------+------------------------------------+
| Maria      | 23              | [apple, mango, orange, banana]     | 
| John       | 55              | [apple, orange, banana]            |
| Brad       | 44              | [banana]                           |
| Alex       | 55              | [apple, mango, orange, banana]     |
+------------+-----------------+------------------------------------+

The "Answers" column contains an array of elements.

My expected output:

+-----+---+--------+-------+                                                              
| Name|Age|  answer| value |
+-----+---+--------+-------+
|Maria| 23|   apple| True  |
|Maria| 23|   mango| True  |
|Maria| 23|  orange| True  |
|Maria| 23|  banana| True  |
| John| 55|   apple| True  |
| John| 55|   mango| False |
| John| 55|  orange| True  |
| John| 55|  banana| True  |
| Brad| 44|   apple| False |
| Brad| 44|   mango| False |
| Brad| 44|  orange| False |
| Brad| 44|  banana| True  |
|Alex | 55|   apple| True  |
|Alex | 55|   mango| True  |
|Alex | 55|  orange| True  |
|Alex | 55|  banana| True  |
+-----+---+--------+-------+

How can I explode the "Answers" column in such a way that I would get the "value" column with True or False based on the array?

For example,

| John| 55|   mango| False |

there is no "mango" in John's answers, so the value is False. Similarly, Brad will have three False rows.

Upvotes: 1

Views: 210

Answers (2)

samkart

Reputation: 6644

An approach using the transform and arrays_zip functions:

from pyspark.sql import functions as func
from pyspark.sql.window import Window as wd

data_sdf. \
    withColumnRenamed('answers', 'ans_per_name'). \
    withColumn('answers', 
               func.array_distinct(func.flatten(func.collect_set('ans_per_name').over(wd.partitionBy())))
               ). \
               ). \
    withColumn('value', 
               func.expr('transform(answers, x -> array_contains(ans_per_name, x))')
               ). \
    withColumn('ans_val_struct', func.arrays_zip('answers', 'value')). \
    selectExpr('name', 'age', 'inline(ans_val_struct)'). \
    show(truncate=False)

# +-----+---+-------+-----+
# |name |age|answers|value|
# +-----+---+-------+-----+
# |Maria|23 |apple  |true |
# |Maria|23 |orange |true |
# |Maria|23 |banana |true |
# |Maria|23 |mango  |true |
# |John |55 |apple  |true |
# |John |55 |orange |true |
# |John |55 |banana |true |
# |John |55 |mango  |false|
# |Brad |44 |apple  |false|
# |Brad |44 |orange |false|
# |Brad |44 |banana |true |
# |Brad |44 |mango  |false|
# |Alex |55 |apple  |true |
# |Alex |55 |orange |true |
# |Alex |55 |banana |true |
# |Alex |55 |mango  |true |
# +-----+---+-------+-----+
  • The idea is to collect all answers across all the names; collect_set over an empty window, combined with flatten and array_distinct, does that.
  • transform checks each collected answer against the answers array originally present for that name; if the element is present, it is marked True.
  • arrays_zip zips the two arrays into an array of structs, where the Nth struct holds the Nth elements of each array.
  • The inline SQL function explodes that array and turns the struct fields into columns.
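The per-row logic of these steps can be sketched in plain Python (hypothetical data, using John's row from the question; array_contains, arrays_zip, and inline are the Spark counterparts):

```python
# Plain-Python analogue of the transform / arrays_zip / inline steps for one row.
all_answers = ['apple', 'orange', 'banana', 'mango']   # collected across all names
johns_answers = ['apple', 'orange', 'banana']          # this name's original array

# transform(answers, x -> array_contains(ans_per_name, x))
value = [a in johns_answers for a in all_answers]

# arrays_zip: the Nth struct pairs the Nth elements of the two arrays;
# inline then emits one row per struct
rows = list(zip(all_answers, value))
# rows == [('apple', True), ('orange', True), ('banana', True), ('mango', False)]
```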

Upvotes: 0

ZygD

Reputation: 24386

Before exploding, you could collect all possible values in the "Answers" column. Add them to the dataframe, explode, and select the required columns.

Input:

from pyspark.sql import functions as F
df = spark.createDataFrame(
    [('Maria', 23, ['apple', 'mango', 'orange', 'banana']),
     ('John', 55, ['apple', 'orange', 'banana']),
     ('Brad', 44, ['banana']),
     ('Alex', 55, ['apple', 'mango', 'orange', 'banana'])],
    ['Name', 'Age', 'Answers'])

Script:

unique_answers = set(df.agg(F.flatten(F.collect_set('Answers'))).head()[0])
df = df.withColumn('answer', F.explode(F.array([F.lit(x) for x in unique_answers])))
df = df.select(
    'Name', 'Age', 'answer',
    F.exists('Answers', lambda x: x == F.col('answer')).alias('value'),
    *[c for c in df.columns if c not in {'Name', 'Age', 'Answers', 'answer'}]
)
df.show()
# +-----+---+------+-----+
# | Name|Age|answer|value|
# +-----+---+------+-----+
# |Maria| 23|orange| true|
# |Maria| 23| mango| true|
# |Maria| 23| apple| true|
# |Maria| 23|banana| true|
# | John| 55|orange| true|
# | John| 55| mango|false|
# | John| 55| apple| true|
# | John| 55|banana| true|
# | Brad| 44|orange|false|
# | Brad| 44| mango|false|
# | Brad| 44| apple|false|
# | Brad| 44|banana| true|
# | Alex| 55|orange| true|
# | Alex| 55| mango| true|
# | Alex| 55| apple| true|
# | Alex| 55|banana| true|
# +-----+---+------+-----+
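For reference, F.exists checks whether any element of the array satisfies the predicate; in plain Python terms (using John's row from the question):

```python
# Plain-Python analogue of F.exists('Answers', lambda x: x == F.col('answer'))
answers = ['apple', 'orange', 'banana']      # John's Answers array
value = any(x == 'mango' for x in answers)   # the "mango" row
# value == False, matching John's mango row in the output
```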

Upvotes: 1
