AMR

Reputation: 85

UDF lookup mapping a pyspark dataframe column

I have a pyspark.sql.dataframe.DataFrame object df which contains Continent and Country codes. I also have a dictionary of dictionaries, dicts, which contains the lookup values for each column.

import pyspark.sql.functions as F
import pyspark.sql.types as T


df = sc.parallelize([('A1','JP'),('A1','CH'),('A2','CA'),
   ('A2','US')]).toDF(['Continent','Country'])

dicts = sc.broadcast(dict([('Country', dict([
                          ('US', 'USA'), 
                          ('JP', 'Japan'),
                          ('CA', 'Canada'),
                          ('CH', 'China')
              ])),
              ('Continent', dict([
                          ('A1','Asia'), 
                          ('A2','America')])
              )
              ]))

+---------+-------+
|Continent|Country|
+---------+-------+
|       A1|     JP|
|       A1|     CH|
|       A2|     CA|
|       A2|     US|
+---------+-------+

I want to replace both Country and Continent with their lookup values. I have tried:

preprocess_request = F.udf(lambda colname, key: 
                       dicts.value[colname].get[key], 
                      T.StringType())
df.withColumn('Continent', preprocess_request('Continent', F.col('Continent')))\
.withColumn('Country', preprocess_request('Country', F.col('Country')))\
.display()

but it gave me an error saying the object is not subscriptable.
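That error comes from .get[key]: dict.get is a method, so it has to be called with parentheses, not indexed with square brackets. A minimal reproduction outside Spark:

```python
d = {'US': 'USA', 'JP': 'Japan'}

try:
    d.get['US']      # square brackets try to subscript the method object itself
except TypeError as e:
    msg = str(e)     # "... object is not subscriptable"

value = d.get('US')  # parentheses actually call the method
# value == 'USA'
```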

What I expect is exactly this:

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+

Upvotes: 4

Views: 2129

Answers (2)

wwnde

Reputation: 26676

I would use a pandas UDF instead of a plain UDF; pandas UDFs are vectorized.

Option 1

from typing import Iterator
import pandas as pd


def map_dict(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Map each batch of rows through the broadcast lookup dictionaries
    for pdf in iterator:
        yield pdf.assign(Continent=pdf.Continent.map(dicts.value['Continent']),
                         Country=pdf.Country.map(dicts.value['Country']))

df.mapInPandas(map_dict, schema=df.schema).show()
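The heavy lifting here is pandas' Series.map, which does a vectorized dictionary lookup on each batch. On a plain pandas frame built from the question's data, the same lookup looks like this (a local sketch, no Spark session required):

```python
import pandas as pd

# Local stand-in for one batch that mapInPandas would hand to map_dict
pdf = pd.DataFrame({'Continent': ['A1', 'A1', 'A2', 'A2'],
                    'Country': ['JP', 'CH', 'CA', 'US']})
lookups = {'Country': {'US': 'USA', 'JP': 'Japan', 'CA': 'Canada', 'CH': 'China'},
           'Continent': {'A1': 'Asia', 'A2': 'America'}}

out = pdf.assign(Continent=pdf.Continent.map(lookups['Continent']),
                 Country=pdf.Country.map(lookups['Country']))
# out.Country.tolist() == ['Japan', 'China', 'Canada', 'USA']
```

Keys missing from the dictionary map to NaN, so make sure the lookups cover every code in the column.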

Option 2

Note, though, that this approach is likely to incur a shuffle.

import pandas as pd


def map_dict(pdf: pd.DataFrame) -> pd.DataFrame:
    # Each group arrives as a plain pandas DataFrame
    return pdf.assign(Continent=pdf.Continent.map(dicts.value['Continent']),
                      Country=pdf.Country.map(dicts.value['Country']))

df.groupby("Continent", "Country").applyInPandas(map_dict, schema=df.schema).show()

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  China|
|     Asia|  Japan|
|  America| Canada|
|  America|    USA|
+---------+-------+

Upvotes: 0

Alex Ott

Reputation: 87249

There is a problem with the arguments to your function: when you specify 'Continent', it's treated as a column name, not a fixed value, so when your UDF is called, the value of that column is passed in, not the word Continent. To fix this, you need to wrap Continent and Country in F.lit (and note that dict.get is a method, so it must be called as .get(key), not subscripted as .get[key]):

preprocess_request = F.udf(lambda colname, key: 
                       dicts.value.get(colname, {}).get(key), 
                      T.StringType())
df.withColumn('Continent', preprocess_request(F.lit('Continent'), F.col('Continent')))\
.withColumn('Country', preprocess_request(F.lit('Country'), F.col('Country')))\
.display()

With that, it gives the correct result:

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+

But really you don't need a UDF for this at all, as UDFs are very slow due to serialization overhead. It can be much faster if you use native PySpark APIs and represent the dictionaries as Spark map literals. Something like this:

continents = F.expr("map('A1','Asia', 'A2','America')")
countries = F.expr("map('US', 'USA', 'JP', 'Japan', 'CA', 'Canada', 'CH', 'China')")
df.withColumn('Continent', continents[F.col('Continent')])\
.withColumn('Country', countries[F.col('Country')])\
.show()

gives you the same answer, but should be much faster:

+---------+-------+
|Continent|Country|
+---------+-------+
|     Asia|  Japan|
|     Asia|  China|
|  America| Canada|
|  America|    USA|
+---------+-------+

Upvotes: 1
