Vivian

Reputation: 105

Apply groupBy with a UDF built on an incrementing function in PySpark

I have the following function:

import copy

rn = 0

def check_vals(x, y):
    global rn
    
    if (y != None) & (int(x)+1) == int(y):
        return rn + 1
    else:
        # Snapshot the current value of rn before it is incremented.
        res = copy.copy(rn)
        # Increment so that the next value will start from +1
        rn += 1
        # Return the same value as we want to group using this
        return res + 1
    
    return 0

@pandas_udf(IntegerType(), functionType=PandasUDFType.GROUPED_AGG)
def check_final(x, y):
    return lambda x, y: check_vals(x, y)

I need to apply this function to the following DataFrame:

index  initial_range  final_range
1         1              299
1        300             499
1        500             699
1        800             1000
2        10              99
2        100             199

I need the following output:

index  min_val   max_val
1        1         699
1        800       1000
2        10        199

Note that within each index the rows are merged into new ranges: min(initial_range) and max(final_range) are taken over consecutive rows until the sequence is broken, and a new range starts at the break.

I tried:

w = Window.partitionBy('index').orderBy(sf.col('initial_range'))

df = (df.withColumn('nextRange', sf.lead('initial_range').over(w))
       .fillna(0,subset=['nextRange'])
       .groupBy('index')
       .agg(check_final("final_range", "nextRange").alias('check_1'))
       .withColumn('min_val', sf.min("initial_range").over(Window.partitionBy("check_1")))
       .withColumn('max_val', sf.max("final_range").over(Window.partitionBy("check_1")))
       )

But it didn't work. Can anyone help me?

Upvotes: 0

Views: 75

Answers (1)

Jonathan

Reputation: 2043

I think the plain Spark SQL API can solve this without any UDF, which could hurt your Spark performance. Two window functions should be enough:

from pyspark.sql import Window
from pyspark.sql import functions as func

df.withColumn(
    'next_row_initial_diff', func.col('initial_range')-func.lag('final_range', 1).over(Window.partitionBy('index').orderBy('initial_range'))
).withColumn(
    'group', func.sum(
        func.when(func.col('next_row_initial_diff').isNull()|(func.col('next_row_initial_diff')==1), func.lit(0))
            .otherwise(func.lit(1))
    ).over(
        Window.partitionBy('index').orderBy('initial_range')
    )
).groupBy(
    'group', 'index'
).agg(
    func.min('initial_range').alias('min_val'),
    func.max('final_range').alias('max_val')
).drop(
    'group'
).show(100, False)

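Column next_row_initial_diff: Similar to the lead in your attempt, but it uses lag. It subtracts the previous row's final_range from the current row's initial_range, so a value of 1 means the row continues the sequence (the first row of each index gets null).

Column group: A running sum of break flags within each index partition. The flag is 0 when the diff is null or 1 and 1 otherwise, so all rows of one unbroken sequence share the same group number.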
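
To make the two columns easier to follow, here is a minimal plain-Python sketch of the same logic (lag, break flag, running sum, then min/max per group), run on the sample rows from the question. It is only an illustration and not Spark code; the function name merge_consecutive_ranges and its local variables are hypothetical.

```python
from collections import defaultdict

def merge_consecutive_ranges(rows):
    """rows: iterable of (index, initial_range, final_range) tuples."""
    # Rough equivalent of Window.partitionBy('index').orderBy('initial_range')
    partitions = defaultdict(list)
    for index, initial, final in rows:
        partitions[index].append((initial, final))

    merged = []
    for index in sorted(partitions):
        group = 0          # running sum of break flags (the 'group' column)
        prev_final = None  # previous row's final_range, i.e. lag('final_range', 1)
        bounds = {}        # group -> [min_val, max_val]
        for initial, final in sorted(partitions[index]):
            # The 'next_row_initial_diff' column: null on the first row
            diff = None if prev_final is None else initial - prev_final
            # A diff other than null or 1 means the sequence is broken
            if diff is not None and diff != 1:
                group += 1
            if group not in bounds:
                bounds[group] = [initial, final]
            else:
                bounds[group][0] = min(bounds[group][0], initial)
                bounds[group][1] = max(bounds[group][1], final)
            prev_final = final
        for g in sorted(bounds):
            merged.append((index, bounds[g][0], bounds[g][1]))
    return merged

# Sample rows from the question: (index, initial_range, final_range)
rows = [
    (1, 1, 299), (1, 300, 499), (1, 500, 699), (1, 800, 1000),
    (2, 10, 99), (2, 100, 199),
]
result = merge_consecutive_ranges(rows)
print(result)  # [(1, 1, 699), (1, 800, 1000), (2, 10, 199)]
```

For index 1 the diffs are null, 1, 1, 101, so only the last row starts a new group, which matches the expected output in the question.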

Upvotes: 1
