frenchast

Reputation: 33

Conditional counting in PySpark

I have the following code:

output = (assignations
          .join(activations,['customer_id','external_id'],'left')
          .join(redeemers,['customer_id','external_id'],'left')
          .groupby('external_id')
          .agg(f.expr('COUNT(DISTINCT(CASE WHEN assignation = 1 THEN customer_id ELSE NULL END))').alias('assigned'),
           f.expr('COUNT(DISTINCT(CASE WHEN activation = 1 THEN customer_id ELSE NULL END))').alias('activated'),
           f.expr('COUNT(DISTINCT(CASE WHEN redeemer = 1 THEN customer_id ELSE NULL END))').alias('redeemed'))
         )

This code gives me the following output:

external_id          assigned     activated    redeemed
DISC0000089309         31968         901         491
DISC0000089428         31719         893         514
DISC0000089283         2617           60          39

I would like to rewrite the CASE WHEN part in a more Pythonic/PySpark style, so I tried the following code:

output = (assignations
          .join(activations,['customer_id','external_id'],'left')
          .join(redeemers,['customer_id','external_id'],'left')
          .groupby('external_id')
          .agg(f.count(f.when(f.col('assignation')==1,True).alias('assigned')),
           f.count(f.when(f.col('activation')==1,True).alias('activated')),
           f.count(f.when(f.col('redeemer')==1,True).alias('redeem'))
         ))

The problem is that the output is not the same: the numbers don't match. How can I rewrite the code so that it produces the same output?

Upvotes: 0

Views: 351

Answers (1)

mck

Reputation: 42332

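f.count counts every row for which the when expression is non-null, so a customer who appears on several joined rows is counted more than once, whereas your SQL version counts distinct customer_id values. Use f.countDistinct, the equivalent of COUNT(DISTINCT ...) in Spark SQL, and have when return customer_id instead of True. Also apply .alias to the aggregate rather than to the when expression inside it: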

from pyspark.sql import functions as f

output = (assignations
          .join(activations,['customer_id','external_id'],'left')
          .join(redeemers,['customer_id','external_id'],'left')
          .groupby('external_id')
          .agg(
              # when() without otherwise() yields NULL for non-matching rows, and countDistinct ignores NULLs
              f.countDistinct(f.when(f.col('assignation') == 1, f.col('customer_id'))).alias('assigned'),
              f.countDistinct(f.when(f.col('activation') == 1, f.col('customer_id'))).alias('activated'),
              f.countDistinct(f.when(f.col('redeemer') == 1, f.col('customer_id'))).alias('redeemed')
          )
         )
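
To see why the two aggregations can disagree, here is a minimal plain-Python sketch of their semantics. It is only a model, not Spark code: the rows are made up, and the duplicated customer "c1" is an assumption standing in for whatever repeats customers in your joined data.

```python
# Hypothetical joined rows: (external_id, customer_id, activation).
# Customer "c1" appears twice, as could happen when a join matches several rows.
rows = [
    ("DISC1", "c1", 1),
    ("DISC1", "c1", 1),
    ("DISC1", "c2", 1),
    ("DISC1", "c3", None),  # no activation: the flag is NULL after a left join
]

# Models f.count(f.when(activation == 1, True)): one count per matching row.
row_count = sum(1 for _, _, activation in rows if activation == 1)

# Models f.countDistinct(f.when(activation == 1, customer_id)):
# keep customer_id where the condition holds, then count unique values.
distinct_count = len({cust for _, cust, activation in rows if activation == 1})

print(row_count, distinct_count)  # prints: 3 2
```

The row count includes "c1" twice, while the distinct count includes it once, which is the kind of gap you would see between the two versions of the query.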

Upvotes: 2
