Cheryl

Reputation: 77

How to create a new column based on values in an array column in PySpark

I have the following dataframe with codes which represent products:

testdata = [(0, ['a','b','d']), (1, ['c']), (2, ['d','e'])]
df = spark.createDataFrame(testdata, ['id', 'codes'])
df.show()
+---+---------+
| id|    codes|
+---+---------+
|  0|[a, b, d]|
|  1|      [c]|
|  2|   [d, e]|
+---+---------+

Let's say codes a and b represent t-shirts and code c represents sweaters.

tshirts = ['a','b']
sweaters = ['c']

How can I create a column label that checks whether these codes are in the array column and returns the name of the product? Like so:

+---+---------+--------+
| id|    codes|   label|
+---+---------+--------+
|  0|[a, b, d]| tshirts|
|  1|      [c]|sweaters|
|  2|   [d, e]|    none|
+---+---------+--------+

I have already tried a lot of things, among others the following, which does not work:

import pyspark.sql.functions as F
from pyspark.sql.types import StringType

codes = {
    'tshirts': ['a','b'],
    'sweaters': ['c']
}

def any_isin(ref_values, array_to_search):
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            return key
        else:
            return 'none'

any_isin_udf = lambda ref_values: (
    F.udf(lambda array_to_search: any_isin(ref_values, array_to_search), StringType())
)

df_labeled = df.withColumn('label', any_isin_udf(codes)(F.col('codes')))

df_labeled.show()
+---+---------+-------+
| id|    codes|  label|
+---+---------+-------+
|  0|[a, b, d]|tshirts|
|  1|      [c]|   none|
|  2|   [d, e]|   none|
+---+---------+-------+

Upvotes: 4

Views: 4953

Answers (2)

pault

Reputation: 43494

A non-UDF method such as @user10055507's answer using pyspark.sql.functions.array_contains() is preferred, but here is an explanation of what's causing your code to fail:

The error is that both branches of your if/else call return inside the loop, so the function always exits after checking the first key ('tshirts') and never gets to 'sweaters'. Here is a way to modify your UDF to get the desired result:

import pyspark.sql.functions as f
from pyspark.sql.types import StringType

codes = {
    'tshirts': ['a','b'],
    'sweaters': ['c']
}

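def any_isin(ref_values, array_to_search):
    label = 'none'  # default when no product code is found
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            label = key
            break  # first matching label wins (dict insertion order)
    return label  # only return after the loop has had a chance to check every key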

any_isin_udf = lambda ref_values: (
    f.udf(lambda array_to_search: any_isin(ref_values, array_to_search), StringType())
)

df_labeled = df.withColumn('label', any_isin_udf(codes)(f.col('codes')))

df_labeled.show()
#+---+---------+--------+
#| id|    codes|   label|
#+---+---------+--------+
#|  0|[a, b, d]| tshirts|
#|  1|      [c]|sweaters|
#|  2|   [d, e]|    none|
#+---+---------+--------+
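The difference between the two versions can be checked without Spark by calling the plain Python functions directly. The names any_isin_buggy and any_isin_fixed are made up for this illustration, and the fixed one uses an early return instead of break, which behaves the same. Both rely on dicts keeping insertion order (guaranteed since Python 3.7), so 'tshirts' is always tried before 'sweaters':

```python
# Plain-Python check of the UDF logic; no Spark session required.
codes = {'tshirts': ['a', 'b'], 'sweaters': ['c']}

def any_isin_buggy(ref_values, array_to_search):
    # Both branches return, so only the first key is ever examined.
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            return key
        else:
            return 'none'

def any_isin_fixed(ref_values, array_to_search):
    # Return on the first match, or 'none' only after every key was checked.
    for key, values in ref_values.items():
        if any(item in array_to_search for item in values):
            return key
    return 'none'

rows = [['a', 'b', 'd'], ['c'], ['d', 'e']]
buggy = [any_isin_buggy(codes, r) for r in rows]  # ['tshirts', 'none', 'none']
fixed = [any_isin_fixed(codes, r) for r in rows]  # ['tshirts', 'sweaters', 'none']
```

The buggy version reproduces the output from the question: [c] gets 'none' because the loop gives up after 'tshirts' fails to match.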

Update

Here is an alternative non-UDF method using a join:

First turn the codes dictionary into a table:

import pyspark.sql.functions as f
from itertools import chain

codes_df = spark.createDataFrame(
    list(chain.from_iterable(zip([a]*len(b), b) for a, b in codes.items())),
    ["label", "code"]
)
codes_df.show()
#+--------+----+
#|   label|code|
#+--------+----+
#| tshirts|   a|
#| tshirts|   b|
#|sweaters|   c|
#+--------+----+
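The chain/zip expression is fairly dense. As a plain-Python sketch (no Spark needed), it builds the same list of (label, code) tuples as a nested list comprehension, which could be passed to spark.createDataFrame in its place; pairs_chain and pairs_comp are illustrative names only:

```python
from itertools import chain

codes = {'tshirts': ['a', 'b'], 'sweaters': ['c']}

# Answer's version: repeat each label once per code, zip with the codes, then flatten.
pairs_chain = list(chain.from_iterable(zip([a] * len(b), b) for a, b in codes.items()))

# Same pairs built with a nested list comprehension.
pairs_comp = [(label, code) for label, code_list in codes.items() for code in code_list]
```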

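Now do a left join of df and codes_df on the condition that the codes array contains the code, then drop the duplicate rows this produces (for example, [a, b, d] matches both t-shirt codes) with distinct():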

df.alias('l')\
    .join(
        codes_df.alias('r'),
        how='left',
        on=f.expr('array_contains(l.codes, r.code)')
    )\
    .select('id', 'codes', 'label')\
    .distinct()\
    .show()
#+---+---------+--------+
#| id|    codes|   label|
#+---+---------+--------+
#|  2|   [d, e]|    null|
#|  0|[a, b, d]| tshirts|
#|  1|      [c]|sweaters|
#+---+---------+--------+
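One consequence of the join approach is that an array containing codes from two different labels ends up in two rows, one per label. The sketch below is a rough plain-Python stand-in for the join plus distinct(); the function left_join_distinct and the extra row with id 3 are hypothetical, added only to show that case:

```python
codes = {'tshirts': ['a', 'b'], 'sweaters': ['c']}
pairs = [(label, code) for label, code_list in codes.items() for code in code_list]

# Row 3 is a made-up example whose array matches both labels.
rows = [(0, ['a', 'b', 'd']), (1, ['c']), (2, ['d', 'e']), (3, ['a', 'c'])]

def left_join_distinct(rows, pairs):
    """Rough stand-in for: left join on array_contains(l.codes, r.code), then distinct()."""
    out = []
    for id_, arr in rows:
        matched = [label for label, code in pairs if code in arr]
        # Unmatched rows keep a null label (left join); dict.fromkeys de-duplicates in order.
        for label in dict.fromkeys(matched or [None]):
            out.append((id_, arr, label))
    return out

result = left_join_distinct(rows, pairs)
```

Here id 3 yields both a 'tshirts' row and a 'sweaters' row. If exactly one label per id is needed, the UDF or the CASE ... WHEN approach is a better fit. Spark also does not guarantee the row order after distinct(), as the output above already shows.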

Upvotes: 0

Aaron Makubuya

Reputation: 1007

I would use an expression with array_contains. Let's define the input as a dict:

from pyspark.sql.functions import expr, lit, when
from operator import or_
from functools import reduce

label_map = {"tshirts": ["a", "b"], "sweaters": ["c"]}

Next, generate one boolean expression per label that is true if the array contains any of the label's codes (use and_ instead of or_ if every code must be present):

expression_map = {
    label: reduce(or_, [expr("array_contains(codes, '{}')".format(code))
                        for code in code_list])
    for label, code_list in label_map.items()
}

Finally, fold the expressions into a single CASE ... WHEN chain:

label = reduce(
    lambda acc, kv: when(kv[1], lit(kv[0])).otherwise(acc),
    expression_map.items(), 
    lit(None).cast("string")
).alias("label")
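To see what this fold builds, the same reduce can be run over plain SQL strings instead of Column objects. This is an illustrative sketch: the conditions dict and the string formatting are assumptions, not part of the answer, and the per-label conditions are joined with OR (any listed code is enough):

```python
from functools import reduce

label_map = {"tshirts": ["a", "b"], "sweaters": ["c"]}

# One SQL condition per label: true if any of the label's codes is in the array.
conditions = {
    label: " OR ".join("array_contains(codes, '{}')".format(c) for c in code_list)
    for label, code_list in label_map.items()
}

# Same shape as when(cond, lit(label)).otherwise(acc): each step wraps the
# accumulator, so the entry processed last ends up outermost and is tested first.
sql = reduce(
    lambda acc, kv: "CASE WHEN {} THEN '{}' ELSE {} END".format(kv[1], kv[0], acc),
    conditions.items(),
    "NULL",
)
```

The resulting string is a nested CASE that checks 'sweaters' first and falls back to 'tshirts', then NULL. A string like this could presumably also be handed to pyspark.sql.functions.expr directly, though the Column-based fold avoids building SQL by hand.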

Result:

df.withColumn("label", label).show()
# +---+---------+--------+                                                        
# | id|    codes|   label|
# +---+---------+--------+
# |  0|[a, b, d]| tshirts|
# |  1|      [c]|sweaters|
# |  2|   [d, e]|    null|
# +---+---------+--------+
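The fold also decides which label wins when a row matches more than one: the entry processed last becomes the outermost when(), so it is checked first. A plain-Python stand-in (hypothetical, no Spark involved; predicates replace the Column expressions) with a made-up row ['a', 'c'] illustrates this:

```python
from functools import reduce

label_map = {"tshirts": ["a", "b"], "sweaters": ["c"]}

# Predicates standing in for the array_contains Column expressions (any code present).
conditions = {
    label: (lambda arr, cs=code_list: any(c in arr for c in cs))
    for label, code_list in label_map.items()
}

# Mirror of when(cond, lit(label)).otherwise(acc): each new entry wraps the previous one.
classify = reduce(
    lambda acc, kv: (lambda arr, cond=kv[1], label=kv[0], rest=acc:
                     label if cond(arr) else rest(arr)),
    conditions.items(),
    lambda arr: None,  # plays the role of lit(None)
)

labels = [classify(a) for a in (['a', 'b', 'd'], ['c'], ['d', 'e'], ['a', 'c'])]
# ['tshirts', 'sweaters', None, 'sweaters']
```

If 'tshirts' should take precedence instead, listing it last in label_map (or reversing the items before the reduce) would flip the priority.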

Upvotes: 2
