Pj-

Reputation: 440

Trim Group by Column/Series Sequence in Pandas by NaN Occurrence

I have a data frame as follows:

user_id metric_date metric_val1 is_churn
3 2021-01 NaN True
3 2021-02 NaN True
3 2021-03 0.4 False
3 2021-04 0.5 False
3 2021-05 NaN True
4 2021-01 0.1 False
4 2021-02 0.3 False
4 2021-03 0.2 False
4 2021-04 NaN True
4 2021-05 NaN True

Suppose there are other metric columns, but the main reference is metric_val1. How can I group by user_id, drop every row with NaN before the first valid metric_val1, and keep only the first NaN row that follows the last valid metric_val1? The output should look something like this (assume there are no gaps in the valid values):

user_id metric_date metric_val1 is_churn
3 2021-03 0.4 False
3 2021-04 0.5 False
3 2021-05 NaN True
4 2021-01 0.1 False
4 2021-02 0.3 False
4 2021-03 0.2 False
4 2021-04 NaN True

Can someone help me with an efficient way to do that in pandas?

Upvotes: 0

Views: 103

Answers (1)

wwnde

Reputation: 26676

Use a boolean mask that selects every row where metric_val1 is not NaN, plus every NaN row that immediately follows a non-NaN value within its user_id group. Code below:

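prev = df.groupby('user_id')['metric_val1'].shift()  # previous metric_val1 within each user_id group
df[df['metric_val1'].notna() | prev.notna()]  # valid rows, plus the first NaN right after a valid value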



    user_id metric_date  metric_val1  is_churn
2        3     2021-03          0.4     False
3        3     2021-04          0.5     False
4        3     2021-05          NaN      True
5        4     2021-01          0.1     False
6        4     2021-02          0.3     False
7        4     2021-03          0.2     False
8        4     2021-04          NaN      True
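
For anyone who wants to reproduce the result above, here is a minimal self-contained sketch. It rebuilds the sample frame from the question and wraps the mask in a helper; the name `trim_by_metric` and its `key`/`value` parameters are hypothetical, chosen only for illustration. It assumes the rows are already sorted by metric_date within each user_id and that, as stated in the question, there are no gaps in the valid values.

```python
import numpy as np
import pandas as pd

# Sample data copied from the question.
df = pd.DataFrame({
    "user_id": [3] * 5 + [4] * 5,
    "metric_date": ["2021-01", "2021-02", "2021-03", "2021-04", "2021-05"] * 2,
    "metric_val1": [np.nan, np.nan, 0.4, 0.5, np.nan,
                    0.1, 0.3, 0.2, np.nan, np.nan],
})
df["is_churn"] = df["metric_val1"].isna()


def trim_by_metric(frame, key="user_id", value="metric_val1"):
    """Hypothetical helper: keep valid rows plus the first NaN after a valid row."""
    # Previous value within each group; NaN on the first row of every group.
    prev = frame.groupby(key)[value].shift()
    # A row survives if its own value is valid or the previous one was.
    keep = frame[value].notna() | prev.notna()
    return frame[keep]


trimmed = trim_by_metric(df)
print(trimmed)
```

Because `groupby(...).shift()` returns a Series aligned with the original index, the mask can be applied to the frame directly, without a Python-level lambda per group.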

If you have a large DataFrame and are worried about memory and speed, you could try PySpark, which is scalable. Just instantiate a PySpark session and load the data as a Spark DataFrame:

import pyspark.sql.functions as F
from pyspark.sql import Window

# df is assumed to be a Spark DataFrame here, with metric_val1 stored as a double
k = Window.partitionBy('user_id').orderBy('metric_date')
(
  df.withColumn('t', F.lag('metric_val1').over(k))  # t holds the previous row's metric_val1 within each user
  # Keep rows where metric_val1 is valid, or where the previous value exists and is valid
  .filter(~F.isnan('metric_val1') | (F.col('t').isNotNull() & ~F.isnan('t')))
  .drop('t')  # drop the temporary column
).show()

+-------+-----------+-----------+--------+
|user_id|metric_date|metric_val1|is_churn|
+-------+-----------+-----------+--------+
|      3|    2021-03|        0.4|   false|
|      3|    2021-04|        0.5|   false|
|      3|    2021-05|        NaN|    true|
|      4|    2021-01|        0.1|   false|
|      4|    2021-02|        0.3|   false|
|      4|    2021-03|        0.2|   false|
|      4|    2021-04|        NaN|    true|
+-------+-----------+-----------+--------+

Upvotes: 2
