Sri

Reputation: 77

Spark dataframe foreachPartition: sum the elements using pyspark

I am trying to partition a Spark dataframe and sum the elements in each partition using pyspark, but I am unable to do this inside the called function "sumByHour". Basically, I am unable to access the dataframe columns inside "sumByHour".

I am partitioning by the "hour" column and trying to sum the elements within each "hour" partition. The expected output is 6, 15, 24 for hours 0, 1, 2 respectively. I tried the below with no luck.

from pyspark.sql.functions import * 
from pyspark.sql.types import *

import pandas as pd

def sumByHour(ip):
    print(ip)

pandasDF = pd.DataFrame({'hour': [0,0,0,1,1,1,2,2,2], 'numlist': [1,2,3,4,5,6,7,8,9]})
myschema = StructType(
                    [StructField('hour', IntegerType(), False),
                     StructField('numlist', IntegerType(), False)] 
                  )
myDf = spark.createDataFrame(pandasDF, schema=myschema)
mydf = myDf.repartition(3, "hour")
myDf.foreachPartition(sumByHour)

I am able to solve this with "window.partitionBy", but I want to know if it can be solved with "foreachPartition".
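(For reference, a plain groupBy aggregation gives the totals I expect; this is just a minimal sketch for context, separate from the foreachPartition attempt above.)

myDf.groupBy("hour").sum("numlist").show()

This yields the sums 6, 15, and 24 for hours 0, 1, and 2.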

Thanks in Advance,

Sri

Upvotes: 2

Views: 1310

Answers (2)

Matt Andruff

Reputation: 5125

Thanks for the code sample, it made this easy. Here's a really simple example that modifies your sumByHour code:

def sumByHour(ip):
    mySum = 0
    myPartition = ""
    for x in ip:
        mySum += x.numlist
        myPartition = x.hour
    myString = '{}_{}'.format(mySum, myPartition)
    print(myString)

mydf = myDf.repartition(5,"hour") #wait 5 I wanted 3!!!

You get almost the expected result:

>>> mydf.foreachPartition(sumByHour)
0_
0_
24_2
6_0
15_1
>>> 

You might ask why partition by '5' and not '3'? It turns out the hash formula used with 3 partitions puts hours 0 and 1 into the same partition (a collision) and leaves one partition empty (bad luck). So this will work, but you only want to use it on data that will fit into memory.
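If you want to see the collision for yourself, one way (a quick sketch, assuming the same myDf from the question) is to tag each row with spark_partition_id and look at which partition each hour lands in:

from pyspark.sql.functions import spark_partition_id

# Show which partition each hour hashes into when repartitioning by "hour" into 3 partitions.
myDf.repartition(3, "hour") \
    .withColumn("pid", spark_partition_id()) \
    .select("hour", "pid").distinct().show()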

Upvotes: 2

BoomBoxBoy

Reputation: 1885

You can use a Window to do that and add the sumByHour as a new column.

from pyspark.sql import functions, Window

w = Window.partitionBy("hour")

myDf = myDf.withColumn("sumByHour", functions.sum("numlist").over(w))
myDf.show()

+----+-------+---------+
|hour|numlist|sumByHour|
+----+-------+---------+
|   1|      4|       15|
|   1|      5|       15|
|   1|      6|       15|
|   2|      7|       24|
|   2|      8|       24|
|   2|      9|       24|
|   0|      1|        6|
|   0|      2|        6|
|   0|      3|        6|
+----+-------+---------+
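If you only need one row per hour (the 6, 15, 24 from the question), you can take the distinct hour/sum pairs from the same result:

myDf.select("hour", "sumByHour").distinct().show()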

Upvotes: 1
