Amber Z.

Reputation: 379

PySpark conditional increment

I'm brand new to PySpark and I'm trying to convert some Python code that derives a new variable, 'COUNT_IDX'. The new variable has an initial value of 1 and is incremented by 1 whenever a condition is met. Otherwise it keeps the value it had on the previous record.

The condition to increment is: TRIP_CD is not equal to the previous record's TRIP_CD, or SIGN is not equal to the previous record's SIGN, or time_diff is not equal to 1.

Python code (pandas dataframe):

df['COUNT_IDX'] = 1
# positional index of the column, so writes avoid chained assignment
count_col = df.columns.get_loc('COUNT_IDX')

for i in range(1, len(df)):
    if ((df['TRIP_CD'].iloc[i] != df['TRIP_CD'].iloc[i - 1])
          or (df['SIGN'].iloc[i] != df['SIGN'].iloc[i - 1])
          or df['time_diff'].iloc[i] != 1):
        df.iloc[i, count_col] = df.iloc[i - 1, count_col] + 1
    else:
        df.iloc[i, count_col] = df.iloc[i - 1, count_col]

Here are the expected results:

TRIP_CD   SIGN   time_diff  COUNT_IDX
2711      -      1          1
2711      -      1          1
2711      +      2          2
2711      -      1          3
2711      -      1          3
2854      -      1          4
2854      +      1          5

In PySpark, I initialize COUNT_IDX as 1. Then using the Window function, I took the lags of TRIP_CD and SIGN and calculated the time_diff, then tried:

df = sqlContext.sql('''
   select TRIP, TRIP_CD, SIGN, TIME_STAMP, seconds_diff,
   case when TRIP_CD != TRIP_lag or SIGN != SIGN_lag  or  seconds_diff != 1 
        then (lag(COUNT_INDEX) over(partition by TRIP order by TRIP, TIME_STAMP))+1
        else (lag(COUNT_INDEX) over(partition by TRIP order by TRIP, TIME_STAMP)) 
        end as COUNT_INDEX from df''')

This is giving me something like:

TRIP_CD   SIGN   time_diff  COUNT_IDX
2711      -      1          1
2711      -      1          1
2711      +      2          2
2711      -      1          2
2711      -      1          1
2854      -      1          2
2854      +      1          2

If COUNT_IDX is updated on a previous record, the current record doesn't see that change when its own COUNT_IDX is calculated. It's as if COUNT_IDX is not being overwritten, or is not being evaluated from row to row. Any ideas on how I can get around this?

Upvotes: 2

Views: 2617

Answers (1)

zero323

Reputation: 330183

You need a cumulative sum here:

-- cumulative sum
SUM(CAST(  
  -- if at least one condition has been satisfied
  -- we take 1 otherwise 0
  TRIP_CD != TRIP_lag OR SIGN != SIGN_lag OR seconds_diff != 1 AS LONG
)) OVER W
...
WINDOW W AS (PARTITION BY TRIP ORDER BY TIME_STAMP)
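
To see why a cumulative sum gives the wanted index, here is a minimal plain-Python sketch of the same idea, run on the sample rows from the question. It is not Spark code: the helper name `change_flags` and the tuple layout are assumptions made for illustration. Each row gets a flag of 1 if the condition holds against the previous row and 0 otherwise, and COUNT_IDX is the running total of those flags. The first row is flagged by hand here so the total starts at 1. In Spark the lagged values on the first row of each partition are NULL, so the first flag may come out NULL rather than 1, and the result may need a COALESCE or an offset; that is worth checking on real data.

```python
from itertools import accumulate

# Sample rows from the question: (TRIP_CD, SIGN, time_diff)
rows = [
    (2711, "-", 1),
    (2711, "-", 1),
    (2711, "+", 2),
    (2711, "-", 1),
    (2711, "-", 1),
    (2854, "-", 1),
    (2854, "+", 1),
]


def change_flags(rows):
    """Return 1 for each row where the increment condition holds, else 0."""
    # Assumption: the first row has no predecessor, so it is flagged
    # explicitly to make the running total start at 1.
    flags = [1]
    for prev, cur in zip(rows, rows[1:]):
        trip_changed = cur[0] != prev[0]
        sign_changed = cur[1] != prev[1]
        gap = cur[2] != 1
        flags.append(int(trip_changed or sign_changed or gap))
    return flags


flags = change_flags(rows)
# Running total of the flags plays the role of SUM(...) OVER W.
count_idx = list(accumulate(flags))
print(count_idx)
```

Unlike the lag-based attempt, nothing here reads COUNT_IDX from a previous row: every flag depends only on the input columns, so the whole column can be computed in one pass.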

Upvotes: 1
