nilesh1212

Reputation: 1655

How to execute custom logic at pyspark window partition

I have a dataframe in the format shown below, with multiple rows per depName. My requirement is to set result = Y at the depName level if either flag_1 or flag_2 is Y on any row of that department. If flag_1 and flag_2 are both N on every row, result should be N, as shown for depName = personnel.

I can get the desired result using joins, but I am curious whether it can be done with window functions, since the dataset is very large.

+---------+------+------+------+
|  depName|flag_1|flag_2|result|
+---------+------+------+------+
|    sales|     N|     Y|     Y|
|    sales|     N|     N|     Y|
|    sales|     N|     N|     Y|
|personnel|     N|     N|     N|
|personnel|     N|     N|     N|
|  develop|     Y|     N|     Y|
|  develop|     N|     N|     Y|
|  develop|     N|     N|     Y|
|  develop|     N|     N|     Y|
|  develop|     N|     N|     Y|
+---------+------+------+------+

Upvotes: 1

Views: 312

Answers (2)

Gordon Linoff

Reputation: 1270401

This answers the original version of the question.

This looks like a case expression:

select t.*,
       (case when flag_1 = 'Y' or flag_2 = 'Y'
             then 'Y' else 'N'
        end) as result
from t

For the updated version:

select t.*,
       max(case when flag_1 = 'Y' or flag_2 = 'Y'
                then 'Y' else 'N'
           end) over (partition by depname) as result
from t
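
The window version relies on 'Y' sorting after 'N' as a string, so max() over the partition returns 'Y' as soon as one row in the department qualifies. Here is a minimal sketch for trying it out locally, assuming Python's built-in sqlite3 module is linked against SQLite 3.25 or later (the first release with window functions). The table name t and the sample rows are assumptions taken from the question, not from a real dataset.

```python
import sqlite3

# In-memory database; the table name "t" is a placeholder for the real table.
conn = sqlite3.connect(":memory:")
conn.execute("create table t (depname text, flag_1 text, flag_2 text)")

# A subset of the sample rows from the question.
rows = [
    ("sales", "N", "Y"),
    ("sales", "N", "N"),
    ("sales", "N", "N"),
    ("personnel", "N", "N"),
    ("personnel", "N", "N"),
    ("develop", "Y", "N"),
    ("develop", "N", "N"),
]
conn.executemany("insert into t values (?, ?, ?)", rows)

# max() over the partition picks 'Y' if any row in the department qualifies,
# because 'Y' compares greater than 'N' as text.
query = """
select t.*,
       max(case when flag_1 = 'Y' or flag_2 = 'Y'
                then 'Y' else 'N'
           end) over (partition by depname) as result
from t
"""
results = conn.execute(query).fetchall()

# Every row of a department carries the same result, so one entry per
# department is enough to inspect the outcome.
result_by_dep = {dep: res for dep, _, _, res in results}
print(result_by_dep)
```

No join or group by is needed: the window function keeps every input row and only adds the department-level value.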

Upvotes: 1

Ric S

Reputation: 9257

If you are using PySpark (since you included it in the tags), and assuming your dataframe is called df, you can use:

import pyspark.sql.functions as F
from pyspark.sql.window import Window

w = Window.partitionBy('depName')

df = df\
  .withColumn('cnt', F.sum(F.when((F.col('flag_1') == 'Y') | (F.col('flag_2') == 'Y'), 1).otherwise(0)).over(w))\
  .withColumn('result', F.when(F.col('cnt') >= 1, 'Y').otherwise('N'))

df.show()

+---------+------+------+---+------+
|  depName|flag_1|flag_2|cnt|result|
+---------+------+------+---+------+
|  develop|     Y|     N|  1|     Y|
|  develop|     N|     N|  1|     Y|
|  develop|     N|     N|  1|     Y|
|  develop|     N|     N|  1|     Y|
|  develop|     N|     N|  1|     Y|
|personnel|     N|     N|  0|     N|
|personnel|     N|     N|  0|     N|
|    sales|     N|     Y|  1|     Y|
|    sales|     N|     N|  1|     Y|
|    sales|     N|     N|  1|     Y|
+---------+------+------+---+------+

Basically, within each partition determined by depName, you count the rows where the condition (flag_1 == 'Y') | (flag_2 == 'Y') holds, and you store that count in cnt on every row of the partition.
Then a simple .when marks with 'Y' every group that has cnt >= 1, and .otherwise marks the rest with 'N'.

Upvotes: 1
