cnns

Reputation: 171

Filter a grouped dataframe based on column value in pyspark

I have the dataframe below. I want to group it by company and date and, within each group, keep a single row based on category, prioritizing QF if available, otherwise SAF, and otherwise AF. I am trying to assign ranks using a window function, but maybe there is an easier way.

    company     date     value  category
    ------------------------------------
      xyz    31-12-2020    12      AF
      xyz    31-12-2020    10      SAF
      xyz    31-12-2020    11      QF
      xyz    30-06-2020    14      AF
      xyz    30-06-2020    16      SAF
      xyz    30-09-2020    13      SAF
      xyz    31-03-2019    20      AF

Expected output:

    company     date     value  category
    ------------------------------------
      xyz    31-12-2020    11      QF
      xyz    30-06-2020    16      SAF
      xyz    30-09-2020    13      SAF
      xyz    31-03-2019    20      AF
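
For reproducibility, a minimal snippet to build this sample DataFrame (assuming an existing SparkSession named spark and dates kept as strings) would be:

    data = [
        ('xyz', '31-12-2020', 12, 'AF'),
        ('xyz', '31-12-2020', 10, 'SAF'),
        ('xyz', '31-12-2020', 11, 'QF'),
        ('xyz', '30-06-2020', 14, 'AF'),
        ('xyz', '30-06-2020', 16, 'SAF'),
        ('xyz', '30-09-2020', 13, 'SAF'),
        ('xyz', '31-03-2019', 20, 'AF'),
    ]
    df = spark.createDataFrame(data, ['company', 'date', 'value', 'category'])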

Upvotes: 2

Views: 573

Answers (3)

samkart

Reputation: 6654

We can assign a rank to the categories with chained when() conditions and retain the records that have the minimum rank in each group.

import pyspark.sql.functions as func
from pyspark.sql.window import Window as wd

# cat_rank: 1 = QF (preferred), 2 = SAF, 3 = AF; any other category stays null
# min_cat_rank: the best rank available within each company + date group
# the isNotNull() filter drops groups that contain no known category at all
data_sdf. \
    withColumn('cat_rank',
               func.when(func.col('cat') == 'QF', func.lit(1)).
               when(func.col('cat') == 'SAF', func.lit(2)).
               when(func.col('cat') == 'AF', func.lit(3))
               ). \
    withColumn('min_cat_rank', 
               func.min('cat_rank').over(wd.partitionBy('company', 'dt'))
               ). \
    filter(func.col('min_cat_rank').isNotNull()). \
    filter(func.col('min_cat_rank') == func.col('cat_rank')). \
    drop('cat_rank', 'min_cat_rank'). \
    show()

# +-------+----------+---+---+
# |company|        dt|val|cat|
# +-------+----------+---+---+
# |    xyz|30-09-2020| 13|SAF|
# |    xyz|30-06-2020| 16|SAF|
# |    xyz|31-03-2019| 20| AF|
# |    xyz|31-12-2020| 11| QF|
# +-------+----------+---+---+

Upvotes: 1

Ric S

Reputation: 9277

Assuming there can be multiple rows with the same category for a given combination of company and date, and that we want to keep the maximum value for the preferred category, here is a solution with two window functions:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w_company_date = Window.partitionBy('company', 'date')
w_company_date_category = Window.partitionBy('company', 'date', 'category')

df = (df
  .withColumn('priority', F.when(F.col('category') == 'QF', 1)
                           .when(F.col('category') == 'SAF', 2)
                           .when(F.col('category') == 'AF', 3)
                           .otherwise(None))
  .withColumn('top_choice', F.when((F.col('priority') == F.min('priority').over(w_company_date))
                                   & (F.col('value') == F.max('value').over(w_company_date_category)), 1)
                             .otherwise(0))
  .filter(F.col('top_choice') == 1)
  .drop('priority', 'top_choice')
)

df.show()

+-------+----------+-----+--------+
|company|      date|value|category|
+-------+----------+-----+--------+
|    xyz|2020-03-31|   20|      AF|
|    xyz|2020-06-30|   16|     SAF|
|    xyz|2020-09-30|   13|     SAF|
|    xyz|2020-12-31|   11|      QF|
+-------+----------+-----+--------+

Upvotes: 1

elyptikus

Reputation: 1148

Assuming that there is only a limited number of categories and that there are no duplicate entries per category, I would suggest mapping the categories to integers by which you can order them. Afterwards you can simply partition, sort, and pick the first entry of each partition.

import pyspark.sql.functions as f
from pyspark.sql.window import Window

# map each category to its priority (1 = most preferred); unknown categories become null
df = df.withColumn('mapping',
            f.when(f.col('category') == 'QF', f.lit(1)).otherwise(
            f.when(f.col('category') == 'SAF', f.lit(2)).otherwise(
            f.when(f.col('category') == 'AF', f.lit(3)).otherwise(f.lit(None)))))

# keep the highest-priority row per company and date
w = Window.partitionBy('company', 'date').orderBy(f.col('mapping'))
df.withColumn('row', f.row_number().over(w))\
   .filter(f.col('row') == 1)\
   .drop('row', 'mapping')\
   .show()
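
Note that Spark sorts nulls first when ordering ascending, so a row whose category is not covered by the mapping (and therefore gets a null) would win the row_number race; if that is a concern, order by f.col('mapping').asc_nulls_last() instead.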

Upvotes: 1
