Brian Behe

Reputation: 542

PySpark DataFrame: Map Strings to Numerics

I'm looking for a way to take a given column of data, in this case strings, and convert its values into a numeric representation. For example, I have a DataFrame with a string column whose values are:

+------------+
|    level   |
+------------+
|      Medium|
|      Medium|
|      Medium|
|        High|
|      Medium|
|      Medium|
|         Low|
|         Low|
|        High|
|         Low|
|         Low|
+------------+

I want to create a new column, level_num, in which these values are converted as follows:

"High" = 1, "Medium" = 2, "Low" = 3

+------------+
|   level_num|
+------------+
|           2|
|           2|
|           2|
|           1|
|           2|
|           2|
|           3|
|           3|
|           1|
|           3|
|           3|
+------------+

I've tried defining a function and calling foreach on the DataFrame's underlying RDD, like so:

def f(x):
    if x == "Medium":
        return 2
    elif x == "Low":
        return 3
    else:
        return 1

a = df.select("level").rdd.foreach(f)

But this returns None rather than the mapped values. Thoughts? Thanks for the help as always!

Upvotes: 6

Views: 9951

Answers (2)

abiratsis

Reputation: 7316

For Spark >= 2.4, an alternative is to represent the mapping as a Python dictionary.

Then use the Spark functions array and map_from_arrays to turn it into a map column, and look up each row's level in that map to fill in the level_num column:

from pyspark.sql import SparkSession
from pyspark.sql.functions import array, col, lit, map_from_arrays

spark = SparkSession.builder.getOrCreate()

_dict = {"High":1, "Medium":2, "Low":3}

df = spark.createDataFrame([
["Medium"], ["Medium"], ["Medium"], ["High"], ["Medium"], ["Medium"], ["Low"], ["Low"], ["High"]
], ["level"])

keys = array(list(map(lit, _dict.keys()))) # or alternatively [lit(k) for k in _dict.keys()]
values = array(list(map(lit, _dict.values())))
_map = map_from_arrays(keys, values)

df.withColumn("level_num", _map.getItem(col("level"))).show()  # or element_at(_map, col("level")), also from pyspark.sql.functions

# +------+---------+
# | level|level_num|
# +------+---------+
# |Medium|        2|
# |Medium|        2|
# |Medium|        2|
# |  High|        1|
# |Medium|        2|
# |Medium|        2|
# |   Low|        3|
# |   Low|        3|
# |  High|        1|
# +------+---------+
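Both map_from_arrays and the older create_map (available since Spark 2.0) build the same kind of map column; they differ only in how the entries are passed. The plain-Python sketch below (no Spark needed) derives both layouts from _dict: parallel key and value lists for map_from_arrays, and one flat alternating list for create_map. It relies on Python's guarantee that keys() and values() of an unmodified dict iterate in corresponding order.

```python
from itertools import chain

_dict = {"High": 1, "Medium": 2, "Low": 3}

# map_from_arrays pairs the i-th key with the i-th value. keys() and
# values() of an unmodified dict iterate in corresponding order, so
# zipping them back together reproduces the original mapping.
keys = list(_dict.keys())
values = list(_dict.values())
assert dict(zip(keys, values)) == _dict

# create_map instead takes one flat list of alternating entries:
# key1, value1, key2, value2, ...
flat = list(chain.from_iterable(_dict.items()))
print(flat)  # ['High', 1, 'Medium', 2, 'Low', 3]
```

In PySpark the flat list would be wrapped in lit and passed as create_map([lit(x) for x in flat]); the lit is needed because create_map treats plain strings as column names. This variant should also work on Spark versions before 2.4, where map_from_arrays is not available.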

Upvotes: 1

desertnaut

Reputation: 60318

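You can certainly do this along the lines you have been trying, but you'll need a map operation instead of foreach: foreach is an action that calls your function only for its side effects and returns None, whereas map is a transformation that returns a new RDD holding the function's results.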

spark.version
# u'2.2.0'

from pyspark.sql import Row
# toy data:
df = spark.createDataFrame([Row("Medium"),
                              Row("High"),
                              Row("High"),
                              Row("Low")
                             ],
                              ["level"])
df.show()
# +------+ 
# | level|
# +------+
# |Medium|
# |  High|
# |  High|
# |   Low|
# +------+

Each element of the RDD is a Row, not a plain string, so pass its first field (x[0]) to your f(x). With these toy data, we get:

df.select("level").rdd.map(lambda x: f(x[0])).collect()
# [2, 1, 1, 3]
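The difference between foreach and map, and the need for x[0], can be seen without Spark. In the plain-Python sketch below, foreach_like and map_like are hypothetical stand-ins for RDD.foreach and rdd.map(...).collect(), and one-element tuples stand in for Row objects (pyspark's Row subclasses tuple).

```python
def f(x):
    # Same logic as the question's f
    if x == "Medium":
        return 2
    elif x == "Low":
        return 3
    else:
        return 1

def foreach_like(func, items):
    """Hypothetical stand-in for RDD.foreach: calls func on every element
    for its side effects only and throws the results away."""
    for item in items:
        func(item)
    # no return statement, so the caller gets None

def map_like(func, items):
    """Hypothetical stand-in for rdd.map(func).collect(): keeps every result."""
    return [func(item) for item in items]

# One-element tuples stand in for pyspark Row objects.
rows = [("Medium",), ("High",), ("High",), ("Low",)]

print(foreach_like(f, rows))              # None
print(map_like(f, rows))                  # [1, 1, 1, 1]: a tuple never equals "Medium" or "Low"
print(map_like(lambda r: f(r[0]), rows))  # [2, 1, 1, 3]
```

Note the middle line: even with map, passing the whole Row to f silently sends every value to the else branch.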

One more map, wrapping each value in a Row, lets you turn the result back into a DataFrame with toDF:

df.select("level").rdd.map(lambda x: f(x[0])).map(lambda x: Row(x)).toDF(["level_num"]).show()
# +---------+ 
# |level_num|
# +---------+
# |        2|
# |        1|
# |        1| 
# |        3|
# +---------+

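It is preferable, though, to avoid the intermediate RDD altogether and use the DataFrame function when in place of your f(x). The logic then runs as a native Spark expression instead of passing every row through a Python function, and the original level column is kept next to the new one: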

from pyspark.sql.functions import col, when

df.withColumn("level_num",
              when(col("level") == "Medium", 2)
              .when(col("level") == "Low", 3)
              .otherwise(1)).show()
# +------+---------+ 
# | level|level_num|
# +------+---------+
# |Medium|        2|
# |  High|        1| 
# |  High|        1|
# |   Low|        3| 
# +------+---------+    
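One behavioural detail worth checking: like f, this when/otherwise chain sends every value other than "Medium" and "Low" to 1, so a misspelled level or a null would silently become "High". The plain-Python sketch below (no Spark required; the "Hgh" typo and the None entry are made-up inputs) contrasts that fallback with a strict dictionary lookup. In Spark, the strict behaviour corresponds to listing all three cases with when and omitting otherwise, which leaves unmatched rows null, or to the map lookup shown in the other answer.

```python
_dict = {"High": 1, "Medium": 2, "Low": 3}

def with_fallback(level):
    """Mirrors when(...).when(...).otherwise(1): any value that is not
    'Medium' or 'Low', including typos and None (null), becomes 1."""
    if level == "Medium":
        return 2
    if level == "Low":
        return 3
    return 1

def strict(level):
    """Explicit lookup: values missing from the mapping become None,
    the analogue of a null in Spark."""
    return _dict.get(level)

levels = ["Medium", "High", "Low", "Hgh", None]
print([with_fallback(x) for x in levels])  # [2, 1, 3, 1, 1]
print([strict(x) for x in levels])         # [2, 1, 3, None, None]
```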

Upvotes: 8
