Kyuu

Reputation: 1045

How to filter pivot tables in Python

How do I filter a pivot table to return specific columns? Currently my DataFrame looks like this:

print table
                    sum            
Sex              Female  Male   All
Date (Intervals)                   
April               166   191   357
August              212   263   475
December            173   263   436
February            192   298   490
January             148   195   343
July                189   260   449
June                165   238   403
March               165   278   443
May                 236   253   489
November            167   247   414
October             185   287   472
September           175   306   481
All                2173  3079  5252

I want to display only the Male column. I tried the following code:

table.query('Sex == "Male"')

However, I got this error:

TypeError: Expected tuple, got str

How can I filter my table by specific rows or columns?

Upvotes: 2

Views: 10535

Answers (1)

unutbu

Reputation: 879201

It looks like table has a column MultiIndex:

                    sum            
Sex              Female  Male   All

One way to check if your table has a column MultiIndex is to inspect table.columns:

In [178]: table.columns
Out[178]: 
MultiIndex(levels=[['sum'], ['All', 'Female', 'Male']],
           labels=[[0, 0, 0], [1, 2, 0]],
           names=[None, 'sex'])
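
The same check can be done programmatically. The sketch below rebuilds the first two rows of the table from the question by hand; that construction is an assumption made for illustration (it is not how the original table was produced), and it uses the question's level name 'Sex'.

```python
import pandas as pd

# Hand-built stand-in for the question's pivot table (first two rows only).
columns = pd.MultiIndex.from_tuples(
    [('sum', 'Female'), ('sum', 'Male'), ('sum', 'All')],
    names=[None, 'Sex'],
)
table = pd.DataFrame(
    [[166, 191, 357], [212, 263, 475]],
    index=pd.Index(['April', 'August'], name='Date (Intervals)'),
    columns=columns,
)

# True when the column index is a MultiIndex.
has_multiindex = isinstance(table.columns, pd.MultiIndex)
# Number of levels in the column index (1 for a flat Index).
n_levels = table.columns.nlevels
print(has_multiindex, n_levels)
```

If nlevels is greater than 1, a plain string such as 'Male' is not enough to identify a column.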

To access a column of table you need to specify a value for each level of the MultiIndex:

In [179]: list(table.columns)
Out[179]: [('sum', 'Female'), ('sum', 'Male'), ('sum', 'All')]

Thus, to select the Male column, you would use

In [176]: table[('sum', 'Male')]
Out[176]: 
date
April         42.0
August        34.0
December      32.0
...
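
Here is a minimal sketch of a few equivalent ways to pull out the Male column, and of combining it with a row selection. The table is rebuilt by hand from three rows of the question's output, which is an assumption for illustration only; xs and .loc are alternatives to the tuple lookup shown above, not something the original table requires.

```python
import pandas as pd

# Hand-built stand-in for the question's pivot table (three rows only).
columns = pd.MultiIndex.from_tuples(
    [('sum', 'Female'), ('sum', 'Male'), ('sum', 'All')],
    names=[None, 'Sex'],
)
table = pd.DataFrame(
    [[166, 191, 357], [212, 263, 475], [2173, 3079, 5252]],
    index=pd.Index(['April', 'August', 'All'], name='Date (Intervals)'),
    columns=columns,
)

# 1. Full tuple, one value per column level: returns a Series.
male = table[('sum', 'Male')]

# 2. Cross-section on the 'Sex' level: returns a DataFrame whose
#    remaining column level is 'sum'.
male_by_level = table.xs('Male', axis=1, level='Sex')

# 3. Rows and column together with .loc: returns a Series.
subset = table.loc[['April', 'August'], ('sum', 'Male')]

print(male)
print(male_by_level)
print(subset)
```

The tuple form is the most direct. The xs form can be handy when you only know the value on the inner level.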

Since the sum level is unnecessary, you could get rid of it by specifying the values parameter when calling df.pivot or df.pivot_table.

table2 = df.pivot_table(index='date', columns='sex', aggfunc='sum', margins=True,
                        values='sum')
# sex        Female   Male     All
# date                            
# April        40.0   40.0    80.0
# August       48.0   32.0    80.0
# December     48.0   44.0    92.0

For example,

import numpy as np
import pandas as pd
import calendar
np.random.seed(2016)
N = 1000
sex = np.random.choice(['Male', 'Female'], size=N)
date = np.random.choice(calendar.month_name[1:13], size=N)
df = pd.DataFrame({'sex':sex, 'date':date, 'sum':1})

# This reproduces a table similar to yours
table = df.pivot_table(index='date', columns='sex', aggfunc='sum', margins=True)
print(table[('sum', 'Male')])

# table2 has a single-level column Index
table2 = df.pivot_table(index='date', columns='sex', aggfunc='sum', margins=True,
                        values='sum')
print(table2['Male'])

Another way to remove the sum level would be to use table = table['sum'], or table.columns = table.columns.droplevel(0).
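
As a rough sketch of those two alternatives, again using a hand-built stand-in for the question's table (an assumption for illustration, with only two rows):

```python
import pandas as pd

# Hand-built stand-in for the question's pivot table (first two rows only).
columns = pd.MultiIndex.from_tuples(
    [('sum', 'Female'), ('sum', 'Male'), ('sum', 'All')],
    names=[None, 'Sex'],
)
table = pd.DataFrame(
    [[166, 191, 357], [212, 263, 475]],
    index=pd.Index(['April', 'August'], name='Date (Intervals)'),
    columns=columns,
)

# Option 1: select the 'sum' group, which leaves only the inner level.
flat_a = table['sum']

# Option 2: drop the outer level in place (done on a copy here so that
# the original table stays unchanged).
flat_b = table.copy()
flat_b.columns = flat_b.columns.droplevel(0)

# Either way, a plain string now selects the column.
print(flat_a['Male'])
print(flat_b['Male'])
```

Note that droplevel(0) discards the outer level regardless of its values, so it is only safe when that level holds a single value such as 'sum'.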

Upvotes: 4
