Hardik Gupta

Reputation: 4790

Convert string list to binary list in pyspark

I have a dataframe like this

data = [("ID1", ['October', 'September', 'August']),
        ("ID2", ['August', 'June', 'May']),
        ("ID3", ['October', 'June'])]
df = spark.createDataFrame(data, ["ID", "MonthList"])
df.show(truncate=False)

+---+----------------------------+
|ID |MonthList                   |
+---+----------------------------+
|ID1|[October, September, August]|
|ID2|[August, June, May]         |
|ID3|[October, June]             |
+---+----------------------------+

I want to compare every row against a default list, assigning 1 if the month is present and 0 otherwise

default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

Hence my expected output is this

+---+----------------------------+------------------+
|ID |MonthList                   |Binary_MonthList  |
+---+----------------------------+------------------+
|ID1|[October, September, August]|[1, 1, 1, 0, 0, 0]|
|ID2|[August, June, May]         |[0, 0, 1, 0, 1, 1]|
|ID3|[October, June]             |[1, 0, 0, 0, 1, 0]|
+---+----------------------------+------------------+

I am able to do this in plain Python, but I don't know how to do it in PySpark
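For reference, the desired mapping in plain Python is just a membership check per default month; a minimal sketch (the helper name `to_binary` is made up for illustration), which the PySpark answers below translate into column expressions or udfs:

```python
default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

def to_binary(months):
    """Return a 0/1 list marking which default months appear in `months`."""
    present = set(months)  # set lookup avoids rescanning the list per month
    return [1 if m in present else 0 for m in default_month_list]

print(to_binary(['October', 'September', 'August']))  # [1, 1, 1, 0, 0, 0]
```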

Upvotes: 4

Views: 1785

Answers (3)

jxc

Reputation: 13998

How about using array_contains():

from pyspark.sql.functions import array, array_contains

df.withColumn('Binary_MonthList',
              array([array_contains('MonthList', c).astype('int')
                     for c in default_month_list])).show()
+---+--------------------+------------------+
| ID|           MonthList|  Binary_MonthList|
+---+--------------------+------------------+
|ID1|[October, Septemb...|[1, 1, 1, 0, 0, 0]|
|ID2| [August, June, May]|[0, 0, 1, 0, 1, 1]|
|ID3|     [October, June]|[1, 0, 0, 0, 1, 0]|
+---+--------------------+------------------+

Upvotes: 3

cronoik

Reputation: 19395

pissall's answer is completely fine. I'm just posting a more general solution that works without a udf and doesn't require you to know the possible values in advance.

A CountVectorizer does exactly what you want. The algorithm adds every distinct value to its vocabulary as long as it fulfils certain criteria (e.g. minimum or maximum occurrence). You can apply the fitted model to a dataframe and it will return a one-hot encoded sparse vector column (which can be converted to a dense vector column) representing the items of the given input column.

from pyspark.ml.feature import CountVectorizer

data = [("ID1", ['October', 'September', 'August']),
        ("ID2", ['August', 'June', 'May', 'August']),
        ("ID3", ['October', 'June'])]
df = spark.createDataFrame(data, ["ID", "MonthList"])

df.show(truncate=False)

# binary=True only checks whether a vocabulary item is present, not how often it occurs
# vocabSize defines the maximum size of the vocabulary
# minDF=1.0 defines in how many rows (1.0 means one row is enough) a value has to appear to be added to the vocabulary
cv = CountVectorizer(inputCol="MonthList", outputCol="Binary_MonthList", vocabSize=12, minDF=1.0, binary=True)

cvModel = cv.fit(df)

df = cvModel.transform(df)

df.show(truncate=False)

cvModel.vocabulary

Output:

+---+----------------------------+
|ID |MonthList                   |
+---+----------------------------+
|ID1|[October, September, August]|
|ID2|[August, June, May, August] |
|ID3|[October, June]             |
+---+----------------------------+

+---+----------------------------+-------------------------+
|ID |MonthList                   |Binary_MonthList         |
+---+----------------------------+-------------------------+
|ID1|[October, September, August]|(5,[1,2,3],[1.0,1.0,1.0])|
|ID2|[August, June, May, August] |(5,[0,1,4],[1.0,1.0,1.0])|
|ID3|[October, June]             |(5,[0,2],[1.0,1.0])      |
+---+----------------------------+-------------------------+

['June', 'August', 'October', 'September', 'May']
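To see what the sparse notation means: `(5,[1,2,3],[1.0,1.0,1.0])` stores the vector size, the non-zero indices, and their values. A plain-Python sketch of the expansion to a dense list (the helper name `to_dense` is made up for illustration):

```python
def to_dense(size, indices, values):
    """Expand Spark's sparse-vector triple (size, indices, values) into a dense list."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense

# ID1's row from the output above, against the learned vocabulary
# ['June', 'August', 'October', 'September', 'May']:
print(to_dense(5, [1, 2, 3], [1.0, 1.0, 1.0]))  # [0.0, 1.0, 1.0, 1.0, 0.0]
```

Note that the index positions follow the model's vocabulary order (most frequent first), not the order of the default list in the question.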

Upvotes: 2

pissall

Reputation: 7399

You can use a udf like this:

from pyspark.sql.functions import udf, col
from pyspark.sql.types import ArrayType, IntegerType

default_month_list = ['October', 'September', 'August', 'July', 'June', 'May']

def_month_list_func = udf(
    lambda x: [1 if i in x else 0 for i in default_month_list],
    ArrayType(IntegerType())
)

df = df.withColumn("Binary_MonthList", def_month_list_func(col("MonthList")))

df.show()
# output
+---+--------------------+------------------+
| ID|           MonthList|  Binary_MonthList|
+---+--------------------+------------------+
|ID1|[October, Septemb...|[1, 1, 1, 0, 0, 0]|
|ID2| [August, June, May]|[0, 0, 1, 0, 1, 1]|
|ID3|     [October, June]|[1, 0, 0, 0, 1, 0]|
+---+--------------------+------------------+

Upvotes: 4
