Reputation: 3283
I am trying to evaluate, in PySpark, the sum of all elements of a dataframe. I wrote the following function:
def sum_all_elements(df):
    df = df.groupBy().sum()
    df = df.withColumn('total', sum(df[colname] for colname in df.columns))
    return df.select('total').collect()[0][0]
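For reference, a minimal way to exercise the function on a small made-up dataframe (the column names and values below are purely illustrative):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
# Toy dataframe: two rows, two numeric columns.
df = spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])
print(sum_all_elements(df))  # 1 + 2 + 3 + 4 = 10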
To speed up the function, I tried converting the dataframe to an RDD and summing like this:
def sum_all_elements_pyspark(df):
    res = df.rdd.map(lambda x: sum(x)).sum()
    return res
But apparently the RDD version is slower than the dataframe one. Is there a way to speed up the RDD function?
Upvotes: 0
Views: 84
Reputation: 5526
Dataframe functions are faster than RDD ones because the Catalyst optimizer optimizes the operations performed over dataframes, but it does not do the same for RDDs.
When you execute an action over the dataframe API, Spark generates an optimized logical plan, converts that optimized logical plan into multiple physical plans, and then runs cost-based optimization to choose the best physical plan.
The final physical plan is RDD-equivalent code to execute, because at the low level RDDs are what actually run. So using a dataframe-API-based function will give you the performance boost you are after.
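As a rough sketch of what that means in practice (the toy dataframe here is made up for illustration, not taken from the question), you can keep the whole computation in a single dataframe aggregation so Catalyst sees all of it, and you can inspect the physical plan it produces with explain():

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([(1, 2), (3, 4)], ["a", "b"])

# Python's built-in sum starts at 0 and folds Columns with `+`,
# producing one row-wise expression: a + b + ...
row_total = sum(F.col(c) for c in df.columns)

# Single aggregation over the whole dataframe, fully visible to Catalyst,
# instead of groupBy().sum() followed by withColumn.
query = df.select(F.sum(row_total).alias("total"))
print(query.collect()[0][0])  # 10

# Show the optimized physical plan Catalyst chose for this query.
query.explain()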
Upvotes: 2