Rv R

Reputation: 311

split pyspark dataframe into multiple dataframes based on a condition

I have a pyspark dataframe which contains the data similar to below:

id  class price  place
1   A      10      US
2   B      5       US
3   B      5       MEXICO
4   A     -20      CANADA
5   C     -15      US
6   C     -5       US
7   D      20      MEXICO
8   A      10      CANADA
9   A     -30      CANADA

I want the sum of the price column for each value of 'class', which I can get to some extent with a groupby:

      df.groupby('class').agg({'price': 'sum'}).show()

output: class   sum(price)
        A       -30
        B        10
        C       -20
        D        20

Now I want to split the data based on the sum(price) obtained. If sum(price) for a class is greater than zero, those rows should go into one dataframe (classes B and D in this case):

id  class price place
2   B     5      US
3   B     5      MEXICO
7   D     20     MEXICO

If sum(price) for a class is less than zero, those rows should go into another dataframe (classes A and C in this case):

id  class price  place
1   A      10    US
8   A      10    CANADA
4   A     -20    CANADA
9   A     -30    CANADA
5   C     -15    US
6   C      -5    US

Each of the two dataframes is then written out as a separate CSV file using PySpark:

df.write.format('csv').option('header', 'true').save(destination_location)

How do I store the groupby result in a dataframe, and how do I split the single dataframe into two dataframes based on the condition above?

Upvotes: 4

Views: 4719

Answers (1)

mck

Reputation: 42422

You can compute the sum over a window partitioned by class, then split the dataframe using two filters. You may also want to decide what should happen when the sum is exactly 0.

from pyspark.sql import functions as F, Window

# attach the per-class total price to every row
summed = df.withColumn('sum', F.sum('price').over(Window.partitionBy('class')))
df1 = summed.filter('sum > 0')   # classes with a positive total (B, D)
df2 = summed.filter('sum < 0')   # classes with a negative total (A, C)
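
If you also want to write each split out as its own CSV, as in the question, something along these lines should work; the output paths here are just placeholders, and the helper 'sum' column is dropped so the files keep only the original columns:

# placeholder output paths; replace with your own destinations
df1.drop('sum').write.format('csv').option('header', 'true').save('/output/positive_classes')
df2.drop('sum').write.format('csv').option('header', 'true').save('/output/negative_classes')

If classes whose total is exactly zero should land in the first split, change the first filter to 'sum >= 0'.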

Upvotes: 2
