Reputation: 542
I'm looking for a way to take a given column of data, in this case strings, and convert it into a numeric representation. For example, I have a dataframe of strings with the values:
+------+
| level|
+------+
|Medium|
|Medium|
|Medium|
|  High|
|Medium|
|Medium|
|   Low|
|   Low|
|  High|
|   Low|
|   Low|
+------+
And I want to create a new column where these values get converted to:
"High"= 1, "Medium" = 2, "Low" = 3
+---------+
|level_num|
+---------+
|        2|
|        2|
|        2|
|        1|
|        2|
|        2|
|        3|
|        3|
|        1|
|        3|
|        3|
+---------+
I've tried defining a function and doing a foreach over the dataframe like so:
def f(x):
    if x == 'Medium':
        return 2
    elif x == 'Low':
        return 3
    else:
        return 1

a = df.select("level").rdd.foreach(f)
But this returns a "None" type. Thoughts? Thanks for the help as always!
Upvotes: 6
Views: 9951
Reputation: 7316
An alternative, for Spark >= 2.4, is to use a Python dictionary to represent the mapping, and then the array and map_from_arrays Spark functions to implement a key-based lookup for filling in the level_num field:
from pyspark.sql.functions import lit, map_from_arrays, array, col

_dict = {"High": 1, "Medium": 2, "Low": 3}

df = spark.createDataFrame([
    ["Medium"], ["Medium"], ["Medium"], ["High"], ["Medium"],
    ["Medium"], ["Low"], ["Low"], ["High"]
], ["level"])

keys = array(list(map(lit, _dict.keys())))     # or alternatively [lit(k) for k in _dict.keys()]
values = array(list(map(lit, _dict.values())))
_map = map_from_arrays(keys, values)

df.withColumn("level_num", _map.getItem(col("level"))).show()  # or element_at(_map, col("level"))
# +------+---------+
# | level|level_num|
# +------+---------+
# |Medium| 2|
# |Medium| 2|
# |Medium| 2|
# | High| 1|
# |Medium| 2|
# |Medium| 2|
# | Low| 3|
# | Low| 3|
# | High| 1|
# +------+---------+
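For Spark < 2.4, where map_from_arrays is not available, a similar lookup can be built with create_map (available since Spark 2.0) - a rough sketch along the same lines, reusing the same _dict:

from itertools import chain
from pyspark.sql.functions import create_map, lit, col

# build a literal map column from alternating key/value literals;
# chain(*_dict.items()) flattens the dict to "High", 1, "Medium", 2, ...
mapping = create_map([lit(x) for x in chain(*_dict.items())])
df.withColumn("level_num", mapping.getItem(col("level"))).show()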
Upvotes: 1
Reputation: 60318
You can certainly do this along the lines you have been trying - you'll need a map operation instead of foreach. foreach is an action that runs a function only for its side effects and always returns None, whereas map returns a new RDD with the transformed values.
spark.version
# u'2.2.0'
from pyspark.sql import Row
# toy data:
df = spark.createDataFrame([Row("Medium"),
                            Row("High"),
                            Row("High"),
                            Row("Low")],
                           ["level"])
df.show()
# +------+
# | level|
# +------+
# |Medium|
# | High|
# | High|
# | Low|
# +------+
Using your f(x) with these toy data, we get:
df.select("level").rdd.map(lambda x: f(x[0])).collect()
# [2, 1, 1, 3]
And one more map will give you a dataframe:
df.select("level").rdd.map(lambda x: f(x[0])).map(lambda x: Row(x)).toDF(["level_num"]).show()
# +---------+
# |level_num|
# +---------+
# | 2|
# | 1|
# | 1|
# | 3|
# +---------+
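Alternatively, you could leave f(x) untouched and wrap it in a udf, which lets you apply it directly with withColumn and skip the RDD detour - roughly something like this (a sketch, untested):

from pyspark.sql.functions import udf, col
from pyspark.sql.types import IntegerType

# wrap the plain Python function as a UDF that returns integers
f_udf = udf(f, IntegerType())
df.withColumn("level_num", f_udf(col("level"))).show()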
But it would be preferable to do it without invoking a temporary intermediate RDD or row-at-a-time Python execution at all, using the native dataframe function when instead of your f(x):
from pyspark.sql.functions import col, when
df.withColumn("level_num", when(col("level")=='Medium', 2).when(col("level")=='Low', 3).otherwise(1)).show()
# +------+---------+
# | level|level_num|
# +------+---------+
# |Medium| 2|
# | High| 1|
# | High| 1|
# | Low| 3|
# +------+---------+
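As a side note, if the label-to-number mapping lives in a dictionary, the when chain can also be built programmatically instead of being spelled out by hand - a sketch assuming a _dict like the one in the other answer:

from pyspark.sql.functions import when, col

_dict = {"High": 1, "Medium": 2, "Low": 3}
items = list(_dict.items())

# start the chain with the first pair, then fold in the remaining ones;
# unmatched labels become null - add .otherwise(...) for a default
expr = when(col("level") == items[0][0], items[0][1])
for k, v in items[1:]:
    expr = expr.when(col("level") == k, v)

df.withColumn("level_num", expr).show()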
Upvotes: 8