Aviral Srivastava

Reputation: 4582

How to nest NumPy's np.where, or apply it one call after the other?

I have a dataframe to which I need to add a column based on a certain condition. I am doing this successfully (How to have list's elements as a condition in np.where()?). However, when I apply the same logic a second time, it does not work.

My dataframe is:

period  period_type
JAN16   month
JAN16   YTD
2017    2017

What I want instead is for the last row to read: 2017 annual. However, I get annual for every row, i.e. the month and YTD rows are changed to annual as well. Code:

import numpy as np
import pandas as pd

def add_period_type(df):
    months = ['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC']
    # rows whose period starts with a month abbreviation, e.g. 'JAN16'
    m = df.period.str.startswith(tuple(months))
    df['period_type'] = np.where(m, 'month', df.period.str.split().str[0])
    df.loc[~m, 'period'] = df.loc[~m, 'period'].str.split().str[1]
    df["period"] = df["period"].combine_first(df["period_type"])
    years = [str(x) for x in range(2000, 2100)]
    # intended: True wherever period is a year such as '2017'
    y = df.period.str == (tuple(years))
    print(y)
    # intended: set period_type to 'annual' for the year rows only
    df['period_type'] = np.where(y, 'annual', df.period_type.str)
    return df

The first few lines add a new column, period_type. I then want to modify this column based on the condition described above: check whether the value is a year and, if it is, set period_type to annual. Instead, this code assigns annual to every row.
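A minimal sketch of a fix, assuming the raw period values look like 'JAN16', 'YTD JAN16' and '2017'. In the line y = df.period.str == (tuple(years)), the .str accessor object itself is compared to a tuple, so it does not produce a per-row boolean mask. Series.isin (the same method used in jxc's answer) does give an element-wise mask, and the second np.where should fall back to the period_type column itself rather than its .str accessor. Everything else is left as in the question; the sample DataFrame is a hypothetical reconstruction of the input.

```python
import numpy as np
import pandas as pd

def add_period_type(df):
    months = ['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN',
              'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC']
    # True where the raw period starts with a month abbreviation, e.g. 'JAN16'
    m = df.period.str.startswith(tuple(months))
    # month rows -> 'month'; other rows take their first token, e.g. 'YTD' or '2017'
    df['period_type'] = np.where(m, 'month', df.period.str.split().str[0])
    # non-month rows keep their second token ('JAN16' from 'YTD JAN16'), NaN if absent
    df.loc[~m, 'period'] = df.loc[~m, 'period'].str.split().str[1]
    # rows without a second token (e.g. '2017') get their value back from period_type
    df['period'] = df['period'].combine_first(df['period_type'])
    years = [str(x) for x in range(2000, 2100)]
    # element-wise membership test: one boolean per row
    y = df.period.isin(years)
    # year rows -> 'annual'; all other rows keep their existing period_type
    df['period_type'] = np.where(y, 'annual', df.period_type)
    return df

# hypothetical input reconstructed from the table in the question
df = pd.DataFrame({'period': ['JAN16', 'YTD JAN16', '2017']})
print(add_period_type(df))
```

The two np.where calls simply run one after the other; the second only overwrites rows where its own mask is True, which is why it must receive a real boolean Series.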

Upvotes: 2

Views: 1187

Answers (2)

Kosmonaut

Reputation: 128

In terms of performance, I find that nested np.where calls are generally faster than np.select, although np.select catches up at larger sizes (in the timings below it is slightly faster at n = 1,000,000).

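import numpy as np
import pandas as pd


sizes = [100, 10000, 1000000]

# %timeit is an IPython/Jupyter magic: run this in a notebook or IPython shell
for n in sizes:
    x = np.random.randint(9, size=n)            # integers 0..8
    y = np.random.randint(2, size=n)            # 0 or 1
    z = np.random.choice(['N', 'Y'], size=n)
    # three mutually exclusive conditions and the value to pick for each
    c = [((y == 0) & (z == 'Y')), ((y != 0) & (z == 'Y')), ((y != 0) & (z == 'N'))]
    o = [x, x * 2, x * 3]
    %timeit np.select(c, o, 0)
    # the same logic as nested np.where calls, with 0 as the innermost default
    %timeit np.where(c[0], o[0], np.where(c[1], o[1], np.where(c[2], o[2], 0)))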


n = 100
np.select:        38.2 µs ± 281 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
nested np.where:  3.43 µs ± 145 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
n = 10000
np.select:        167 µs ± 1.45 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
nested np.where:  88.4 µs ± 909 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
n = 1000000
np.select:        15.9 ms ± 177 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
nested np.where:  16.9 ms ± 181 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
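The comparison only makes sense if both expressions compute the same result. A quick sanity check, reusing the benchmark's conditions and choices; it swaps in np.random.default_rng with a fixed seed (an assumption, not part of the original answer) so the check is reproducible:

```python
import numpy as np

rng = np.random.default_rng(0)  # fixed seed so the check is reproducible
n = 10_000
x = rng.integers(9, size=n)     # integers 0..8
y = rng.integers(2, size=n)     # 0 or 1
z = rng.choice(['N', 'Y'], size=n)

# same three conditions and choices as in the benchmark
c = [(y == 0) & (z == 'Y'), (y != 0) & (z == 'Y'), (y != 0) & (z == 'N')]
o = [x, x * 2, x * 3]

selected = np.select(c, o, 0)
nested = np.where(c[0], o[0], np.where(c[1], o[1], np.where(c[2], o[2], 0)))

# np.select takes the first condition that is True; the nested np.where tests
# the conditions in the same order, so the two arrays should be identical
print(np.array_equal(selected, nested))
```

Because both forms check the conditions in the same order, the equivalence holds even when conditions overlap, not just for the mutually exclusive ones used here.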

Upvotes: 1

jxc

Reputation: 13998

Use np.select():

import numpy as np
import pandas as pd
from io import StringIO

data = """period
JAN16
YTD JAN16
2017"""

# sample dataframe
df = pd.read_csv(StringIO(data))

months = ['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC']
years = [str(x) for x in range(2000, 2100)]

# condition for month
m = df.period.str[:3].isin(months)

# condition for annual
y = df.period.isin(years)

# rows containing a space, e.g. 'YTD JAN16', become 'JAN16, YTD'
n = df.period.str.contains(r'\s')

df['period_type'] = np.select([m, y, n], ['month', 'annual', df.period.str.split().str[::-1].str.join(', ')])
print(df)
#      period period_type
#0      JAN16       month
#1  YTD JAN16  JAN16, YTD
#2       2017      annual
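Since the question title asks about nested np.where, here is a minimal sketch of the same three conditions written as nested np.where calls instead of np.select. It reuses the conditions above on the same three sample values, but builds the DataFrame directly instead of via read_csv; the helper name swapped is illustrative only. Because np.select picks the first condition that is True, nesting in the order m, y, n should give the same column.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'period': ['JAN16', 'YTD JAN16', '2017']})

months = ['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN',
          'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC']
years = [str(x) for x in range(2000, 2100)]

m = df.period.str[:3].isin(months)     # month rows, e.g. 'JAN16'
y = df.period.isin(years)              # year rows, e.g. '2017'
n = df.period.str.contains(r'\s')      # rows with a space, e.g. 'YTD JAN16'
# hypothetical helper column: 'YTD JAN16' -> 'JAN16, YTD'
swapped = df.period.str.split().str[::-1].str.join(', ')

# conditions are tested in the same order as the np.select call: m, then y, then n;
# the innermost 0 mirrors np.select's default
df['period_type'] = np.where(m, 'month',
                    np.where(y, 'annual',
                    np.where(n, swapped, 0)))
print(df)
```

For more than two or three conditions, np.select keeps the condition and choice lists side by side, which is usually easier to read than deeper nesting.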

Upvotes: 2
