Matt

Converting apply from pandas to a pandas_udf

How can I convert the following sample code to a pandas_udf:

def calculate_courses_final_df(this_row):
    # some code that applies to each row of the data
    ...

df_contracts_courses.apply(calculate_courses_final_df, axis=1)

df_contracts_courses is a pandas DataFrame (not grouped), and the function is applied to each row of it to generate an output. Ideally, I would have df_contracts_courses as a Spark DataFrame and apply the pandas_udf to it directly.
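
For comparison, here is a minimal pandas sketch of the row-wise pattern above. The body of the row function is hypothetical (the original is elided): it just adds 1 to Avg_Num_Training, and the two-row sample data is made up. With axis=1, pandas calls the function once per row and passes that row as a pd.Series indexed by column name. When the logic only combines column values arithmetically, the same result can usually be computed on whole columns in one call, which is also the kind of input a Series-to-Series pandas_udf receives.

```python
import pandas as pd

# Hypothetical sample data using the column names from the question
df_contracts_courses = pd.DataFrame({
    "WD_Customer_ID": ["a", "b"],
    "Avg_Num_Training": [2.2, 7.7],
})

def calculate_courses_final_df(this_row: pd.Series) -> float:
    # Hypothetical row logic: `this_row` is a single row, indexed by column name
    return this_row["Avg_Num_Training"] + 1

# Row-wise: the function is called once for every row
row_wise = df_contracts_courses.apply(calculate_courses_final_df, axis=1)

# Vectorized: the same arithmetic applied to the whole column in one operation
vectorized = df_contracts_courses["Avg_Num_Training"] + 1
```

Both produce the same values; the vectorized form avoids one Python call per row.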

I tried adding a monotonically increasing ID to the Spark DataFrame, grouping by that ID, and applying a pandas UDF to the grouped DataFrame. It works, but it is much slower than the pandas version. Is there a more efficient way?

Here is what I tried:

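from pyspark.sql.functions import monotonically_increasing_id, pandas_udf, PandasUDFType
from pyspark.sql.types import StructType, StructField, StringType, DoubleType

schema = StructType([StructField('WD_Customer_ID', StringType(), True),
                     StructField('Avg_Num_Training', DoubleType(), True)])

@pandas_udf(schema, PandasUDFType.GROUPED_MAP)
def calculate_courses_final_df(this_row):
    # some code (this_row is a pandas DataFrame holding the rows of one group)
    ...

# Give every row a unique ID so that each group contains exactly one row
df_contracts_courses = df_contracts_courses.withColumn("id", monotonically_increasing_id())
df_xu_final_rows_list = df_contracts_courses.limit(100).groupby('id').apply(calculate_courses_final_df)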

It works, but it is slower than plain pandas on a relatively large dataset.
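
A likely reason for the slowdown: because every id is unique, each group holds exactly one row, so the GROUPED_MAP function is invoked once per row, and groupby also forces a shuffle. A Series-to-Series pandas_udf instead receives many rows per call, since Spark converts each partition into Arrow record batches (their size is capped by spark.sql.execution.arrow.maxRecordsPerBatch). The pure-pandas sketch below only models the difference in call counts, not Spark itself; process_batch and run_in_batches are hypothetical helpers.

```python
import pandas as pd

def process_batch(batch: pd.Series) -> pd.Series:
    # Hypothetical vectorized logic applied to a whole batch of values at once
    return batch + 1

def run_in_batches(values: pd.Series, batch_size: int):
    """Apply process_batch to consecutive slices; return (combined result, number of calls)."""
    outputs = []
    calls = 0
    for start in range(0, len(values), batch_size):
        outputs.append(process_batch(values.iloc[start:start + batch_size]))
        calls += 1
    return pd.concat(outputs), calls

values = pd.Series([float(i) for i in range(1000)])

# One row per call, similar to grouping by a unique id
one_per_row, calls_per_row = run_in_batches(values, batch_size=1)

# Many rows per call, similar to a Series-to-Series pandas_udf
batched, calls_batched = run_in_batches(values, batch_size=500)
```

The results are identical, but the first variant makes 1000 calls and the second only 2; in Spark, each call also carries serialization overhead.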

Answers (1)

ZygD

Using this input DataFrame (assuming an existing SparkSession named spark)...

from pyspark.sql import types as T, functions as F
import pandas as pd

df_contracts_courses = spark.createDataFrame(
    [('a', 2.2),
     ('b', 7.7)],
    ['WD_Customer_ID', 'Avg_Num_Training'])

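the following pandas_udf takes one input column and returns one output column. It is a Series-to-Series UDF: Spark passes the column in batches as pd.Series, so despite its name, this_row holds many rows at once: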

@F.pandas_udf(T.DoubleType())
def calculate_courses_final_df(this_row: pd.Series) -> pd.Series:
    # Vectorized over the whole batch: adds 1 to every value
    return this_row + 1

df_xu_final_rows_list = df_contracts_courses.select(
    'WD_Customer_ID',
    calculate_courses_final_df('Avg_Num_Training').alias('final')
)
df_xu_final_rows_list.show()
# +--------------+-----+
# |WD_Customer_ID|final|
# +--------------+-----+
# |             a|  3.2|
# |             b|  8.7|
# +--------------+-----+
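
If the real row logic needs several columns and returns a different set of columns (as the GROUPED_MAP schema in the question suggests), one option not used in the answer above is DataFrame.mapInPandas (Spark 3.0+). It passes the function an iterator of pandas DataFrames, one per batch, and expects pandas DataFrames back, without any grouping. The function below is a minimal sketch with that shape; it is plain pandas so it can be checked without a Spark session, and its name and +1 logic are hypothetical stand-ins for the elided code.

```python
from typing import Iterator

import pandas as pd

def calculate_courses_batches(batches: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
    # Each `pdf` is a batch of many rows, not a single row
    for pdf in batches:
        yield pd.DataFrame({
            "WD_Customer_ID": pdf["WD_Customer_ID"],
            # Hypothetical stand-in for the real per-row logic, vectorized over the batch
            "Avg_Num_Training": pdf["Avg_Num_Training"] + 1,
        })

# Local check: feed two hand-made batches the way Spark would
sample_batches = [
    pd.DataFrame({"WD_Customer_ID": ["a"], "Avg_Num_Training": [2.2]}),
    pd.DataFrame({"WD_Customer_ID": ["b"], "Avg_Num_Training": [7.7]}),
]
result = pd.concat(list(calculate_courses_batches(iter(sample_batches))), ignore_index=True)
```

With a SparkSession, this would be wired up as df_contracts_courses.mapInPandas(calculate_courses_batches, schema=schema), reusing the schema from the question; the returned columns must match that schema. Because no groupby is involved, this should also avoid the shuffle.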
