Prabhu

Reputation: 5466

PySpark: create new column based on another column with multiple conditions using a list or set

I am trying to create a new column in a PySpark data frame. I have data like the following:

+------+
|letter|
+------+
|     A|
|     C|
|     A|
|     Z|
|     E|
+------+

I want to add a new column based on the given column, according to the following mapping:

+------+-----+
|letter|group|
+------+-----+
|     A|   c1|
|     B|   c1|
|     F|   c2|
|     G|   c2|
|     I|   c3|
+------+-----+

There can be multiple categories, each with many individual letter values (around 100 in total; some values consist of multiple letters).

I have done this with a udf, and it works well:

from pyspark.sql.functions import udf
from pyspark.sql.types import StringType

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']
...

def l2c(value):
    if value in c1: return 'c1'
    elif value in c2: return 'c2'
    elif value in c3: return 'c3'
    else: return "na"

udf_l2c = udf(l2c, StringType())
data_with_category = data.withColumn("group", udf_l2c("letter"))

Now I am trying to do it without a udf, perhaps using when and col. What I have tried is the following. It works, but the code is very long:

data_with_category = data.withColumn('group', when(col('letter') == 'A' ,'c1')
    .when(col('letter') == 'B', 'c1')
    .when(col('letter') == 'F', 'c2')
    ... 

It is very long, and writing a new when condition for every possible value of letter is not practical. The number of letters can be very large (around 100) in my case, so I tried:

data_with_category = data.withColumn('group', when(col('letter') in ['A','B','C','D'] ,'c1')
    .when(col('letter') in ['E','F','G','H'], 'c2')
    .when(col('letter') in ['I','J','K','L'], 'c3')

But it returns an error. How can I solve this?

Upvotes: 4

Views: 6155

Answers (2)

murtihash

Reputation: 8410

Use isin.

from pyspark.sql import functions as F

c1 = ['A','B','C','D']
c2 = ['E','F','G','H']
c3 = ['I','J','K','L']

df.withColumn("group", F.when(F.col("letter").isin(c1), F.lit('c1'))\
                        .when(F.col("letter").isin(c2), F.lit('c2'))\
                        .when(F.col("letter").isin(c3), F.lit('c3'))).show()

#+------+-----+
#|letter|group|
#+------+-----+
#|     A|   c1|
#|     B|   c1|
#|     F|   c2|
#|     G|   c2|
#|     I|   c3|
#+------+-----+
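If the number of groups keeps growing, the chain of when calls can also be replaced by a single map lookup. A minimal sketch of building the lookup table in plain Python, assuming the same c1/c2/c3 lists from the question (the names groups and letter_to_group are illustrative, not part of any API):

```python
# Invert the per-group letter lists into one letter -> group dict,
# so the mapping is defined once instead of one `when` per letter.
groups = {
    "c1": ["A", "B", "C", "D"],
    "c2": ["E", "F", "G", "H"],
    "c3": ["I", "J", "K", "L"],
}
letter_to_group = {
    letter: name for name, letters in groups.items() for letter in letters
}

print(letter_to_group.get("A", "na"))  # c1
print(letter_to_group.get("Z", "na"))  # na  (fallback for unmapped letters)
```

On the Spark side this dict can then drive one expression instead of many whens, e.g. `mapping = F.create_map([F.lit(x) for pair in letter_to_group.items() for x in pair])` followed by `df.withColumn("group", mapping[F.col("letter")])`; `create_map` is part of `pyspark.sql.functions`, and indexing a map column with another column performs the lookup per row.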

Upvotes: 6

Phạm Ngọc Quý

Reputation: 329

You can try using a udf, for example:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import StringType

# assumes an active SparkSession named `spark`
say_hello_udf = udf(lambda name: f"Hello {name}", StringType())
df = spark.createDataFrame([("Rick",), ("Morty",)], ["name"])
df.withColumn("greetings", say_hello_udf(col("name"))).show()

or

@udf(returnType=StringType())
def say_hello(name):
    return f"Hello {name}"

df.withColumn("greetings", say_hello(col("name"))).show()

Upvotes: 2
