NineWasps

Reputation: 2253

Pandas: how to shift values in MultiIndex columns

I have a dataframe:

id    week    score                       target
              median   mean   min   max   count
11     1       0.05    0.06   0.01  0.35   12
11     2       0.07    0.09   0.03  0.49   20
11     3       0.15    0.20   0.11  0.70   10
12     2       0.10    0.15   0.02  0.19   19
12     3       0.50    0.59   0.10  0.90   30

I am trying to shift the values in score and target with

data['score_upd'] = data.groupby('id')['score'].shift()
data['target_upd'] = data.groupby('id')['target'].shift()

to get this result

id    week    score                       target
              median   mean   min   max   count
11     1       NaN     NaN    NaN   NaN    NaN
11     2       0.05    0.06   0.01  0.35   12
11     3       0.07    0.09   0.03  0.49   20
12     2       NaN     NaN    NaN   NaN    NaN
12     3       0.10    0.15   0.02  0.19   19

But I get an error:

KeyError: 'score'

I guess the main problem is that data.columns is a MultiIndex. I tried to change it, but nothing happened. How can I shift these columns while keeping the current structure?

Upvotes: 1

Views: 84

Answers (1)

Henry Ecker

Reputation: 35646

One option would be to use get_loc to select specific columns:

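# get_loc('score') gives a boolean mask over the columns here;
# indexing data.columns with it keeps only the 'score' columns
score_cols = data.columns[data.columns.get_loc('score')]
data[score_cols] = (
    data.groupby('id')[score_cols].shift()
)
# Same steps for the 'target' columns
target_cols = data.columns[data.columns.get_loc('target')]
data[target_cols] = (
    data.groupby('id')[target_cols].shift()
)

   id week  score                   target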
           median  mean   min   max  count
0  11    1    NaN   NaN   NaN   NaN    NaN
1  11    2   0.05  0.06  0.01  0.35   12.0
2  11    3   0.07  0.09  0.03  0.49   20.0
3  12    2    NaN   NaN   NaN   NaN    NaN
4  12    3   0.10  0.15  0.02  0.19   19.0

Or combine the two masks with | to do it in one step:

# Both get_loc calls return boolean masks here, so they can be OR-ed
cols = data.columns[data.columns.get_loc('score') |
                    data.columns.get_loc('target')]
data[cols] = (
    data.groupby('id')[cols].shift()
)

   id week  score                   target
           median  mean   min   max  count
0  11    1    NaN   NaN   NaN   NaN    NaN
1  11    2   0.05  0.06  0.01  0.35   12.0
2  11    3   0.07  0.09  0.03  0.49   20.0
3  12    2    NaN   NaN   NaN   NaN    NaN
4  12    3   0.10  0.15  0.02  0.19   19.0

Note that simply slicing the columns by position also works if the columns to shift are all grouped together:

# Positions 0 and 1 are id and week; everything from position 2 on is shifted
cols = data.columns[2:]
data[cols] = (
    data.groupby('id')[cols].shift()
)

   id week  score                   target
           median  mean   min   max  count
0  11    1    NaN   NaN   NaN   NaN    NaN
1  11    2   0.05  0.06  0.01  0.35   12.0
2  11    3   0.07  0.09  0.03  0.49   20.0
3  12    2    NaN   NaN   NaN   NaN    NaN
4  12    3   0.10  0.15  0.02  0.19   19.0

Use Index.map to add a suffix to the top level, so that the shifted values go into new columns instead of overwriting the originals:

cols = data.columns[data.columns.get_loc('score') |
                    data.columns.get_loc('target')]
# Map each (top, sub) column tuple to (top + '_upd', sub) for the new columns
data[cols.map(lambda c: (f'{c[0]}_upd', c[1]))] = (
    data.groupby('id')[cols].shift()
)

   id week  score              ... score_upd                   target_upd
           median  mean   min  ...    median  mean   min   max      count
0  11    1   0.05  0.06  0.01  ...       NaN   NaN   NaN   NaN        NaN
1  11    2   0.07  0.09  0.03  ...      0.05  0.06  0.01  0.35       12.0
2  11    3   0.15  0.20  0.11  ...      0.07  0.09  0.03  0.49       20.0
3  12    2   0.10  0.15  0.02  ...       NaN   NaN   NaN   NaN        NaN
4  12    3   0.50  0.59  0.10  ...      0.10  0.15  0.02  0.19       19.0
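
To see what the Index.map step does on its own, here is a small sketch on a hand-built MultiIndex. The cols below is only an illustrative subset of the column tuples, not the full frame, and renamed is just a name picked for the example:

```python
import pandas as pd

# Illustrative subset of the column tuples from the frame above
cols = pd.MultiIndex.from_tuples([
    ('score', 'median'),
    ('score', 'mean'),
    ('target', 'count'),
])

# Each element handed to the function is a (top, sub) tuple;
# returning a tuple per element yields a new MultiIndex
renamed = cols.map(lambda c: (f'{c[0]}_upd', c[1]))
print(list(renamed))
```

Because the mapped index has the same length and order as cols, assigning the shifted frame to it lines the new columns up one to one with the originals.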

DataFrame and imports:

import pandas as pd

data = pd.DataFrame({
    ('id', ''): {0: 11, 1: 11, 2: 11, 3: 12, 4: 12},
    ('week', ''): {0: 1, 1: 2, 2: 3, 3: 2, 4: 3},
    ('score', 'median'): {0: 0.05, 1: 0.07, 2: 0.15, 3: 0.1, 4: 0.5},
    ('score', 'mean'): {0: 0.06, 1: 0.09, 2: 0.2, 3: 0.15, 4: 0.59},
    ('score', 'min'): {0: 0.01, 1: 0.03, 2: 0.11, 3: 0.02, 4: 0.1},
    ('score', 'max'): {0: 0.35, 1: 0.49, 2: 0.7, 3: 0.19, 4: 0.9},
    ('target', 'count'): {0: 12, 1: 20, 2: 10, 3: 19, 4: 30}
})

How the column selection works with get_loc:

# data.columns

MultiIndex([(    'id',       ''),
            (  'week',       ''),
            ( 'score', 'median'),
            ( 'score',   'mean'),
            ( 'score',    'min'),
            ( 'score',    'max'),
            ('target',  'count')],
           )

# data.columns.get_loc('score')

array([False, False,  True,  True,  True,  True, False])

# data.columns[data.columns.get_loc('score')]
 
MultiIndex([('score', 'median'),
            ('score',   'mean'),
            ('score',    'min'),
            ('score',    'max')],
           )
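
One caveat: MultiIndex.get_loc is documented to return an int, a slice, or a boolean mask, depending on the key and on how the index is sorted. The | approach relies on both calls returning masks, which they do for this frame. If that ever is not the case, a sketch of a selection that does not depend on it is to build the mask from the first level directly. This assumes the same column layout as above; mask and shifted are just names picked for the example:

```python
import pandas as pd

# Same layout as the frame in this answer
data = pd.DataFrame({
    ('id', ''): [11, 11, 11, 12, 12],
    ('week', ''): [1, 2, 3, 2, 3],
    ('score', 'median'): [0.05, 0.07, 0.15, 0.10, 0.50],
    ('score', 'mean'): [0.06, 0.09, 0.20, 0.15, 0.59],
    ('score', 'min'): [0.01, 0.03, 0.11, 0.02, 0.10],
    ('score', 'max'): [0.35, 0.49, 0.70, 0.19, 0.90],
    ('target', 'count'): [12, 20, 10, 19, 30],
})

# Boolean mask: True where the top-level label is 'score' or 'target'
mask = data.columns.get_level_values(0).isin(['score', 'target'])
cols = data.columns[mask]

# Shift the selected columns within each id group
shifted = data.groupby('id')[cols].shift()
print(shifted)
```

The result can then be assigned back with data[cols] = shifted, or to renamed columns as in the Index.map variant.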

Upvotes: 1
