Ferhat

Reputation: 406

Seaborn heatmap - row and column statistics to display

Is it possible to add row and column statistics on the edges of a Seaborn heatmap?

On the right-hand edge I want to display the mean of each row (one per month), and along the bottom edge the mean of each column (one per year).

Upvotes: 0

Views: 1859

Answers (1)

Zephyr

Reputation: 12496

If you are working with a dataframe like this:

df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))
          date  value
0   1949-01-01    202
1   1949-02-01    535
2   1949-03-01    448
3   1949-04-01    370
4   1949-05-01    206
..         ...    ...
139 1960-08-01    238
140 1960-09-01    598
141 1960-10-01    180
142 1960-11-01    491
143 1960-12-01    262

You have to reshape it with pandas.DataFrame.pivot:

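# Three-letter month abbreviation ('Jan', 'Feb', ...); pivot sorts the resulting index alphabetically.
# A .sort_values() call here would have no effect, because column assignment re-aligns on the index.
df['month'] = df['date'].dt.month_name().str.slice(stop = 3)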
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')
year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960
month                                                                        
Apr     370   472   485   574   463   487   543   101   301   395   479   591
Aug     120   230   260   287   230   341   530   359   450   437   114   238
Dec     314   443   352   545   120   485   519   501   561   509   426   262
Feb     535   558   513   444   545   266   191   459   143   351   351   443
Jan     202   430   591   335   274   428   439   149   317   314   316   108
Jul     288   251   376   575   419   113   363   205   369   336   256   162
Jun     171   459   543   269   343   415   527   153   583   307   140   571
Mar     448   187   393   148   150   373   466   487   261   289   287   228
May     206   199   291   158   154   188   554   489   545   312   592   235
Nov     566   357   121   289   234   152   180   290   555   379   444   491
Oct     221   408   413   370   406   445   305   576   370   152   164   180
Sep     202   249   559   563   584   364   134   409   403   466   400   598
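
Note that the pivoted index above is in alphabetical order (Apr, Aug, Dec, ...), not calendar order. If calendar order is preferred on the heatmap, one option is to reindex the pivoted table with an explicit list of month abbreviations. The following is a minimal sketch under that assumption; the names long_df, wide_df and month_order, and the fixed random seed, are illustrative and not part of the original answer.

```python
import numpy as np
import pandas as pd

# Long-format frame as in the answer: one row per month from 1949 to 1960.
rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility
long_df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
long_df['value'] = rng.integers(100, 600, len(long_df))

long_df['month'] = long_df['date'].dt.month_name().str.slice(stop=3)
long_df['year'] = long_df['date'].dt.year

# Explicit calendar order for the three-letter abbreviations.
month_order = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

# pivot returns the months sorted alphabetically; reindex puts the rows in calendar order.
wide_df = long_df.pivot(columns='year', index='month', values='value').reindex(month_order)

print(wide_df.index.tolist())
```

The mean column and mean row can then be added to wide_df in the same way as shown for df.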

Then you can add a column with month mean and a row with year mean:

df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)
year       1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960  month_mean
month                                                                                        
Apr         370   472   485   574   463   487   543   101   301   395   479   591         438
Aug         120   230   260   287   230   341   530   359   450   437   114   238         299
Dec         314   443   352   545   120   485   519   501   561   509   426   262         419
Feb         535   558   513   444   545   266   191   459   143   351   351   443         399
Jan         202   430   591   335   274   428   439   149   317   314   316   108         325
Jul         288   251   376   575   419   113   363   205   369   336   256   162         309
Jun         171   459   543   269   343   415   527   153   583   307   140   571         373
Mar         448   187   393   148   150   373   466   487   261   289   287   228         309
May         206   199   291   158   154   188   554   489   545   312   592   235         326
Nov         566   357   121   289   234   152   180   290   555   379   444   491         338
Oct         221   408   413   370   406   445   305   576   370   152   164   180         334
Sep         202   249   559   563   584   364   134   409   403   466   400   598         410
year_mean   303   353   408   379   326   338   395   348   404   353   330   342         357

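Alternatively, you can pivot and compute the means in one step with pandas.pivot_table. Apply it to the long-format dataframe (after adding the month and year columns), in place of the pivot call and the two mean assignments above: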

df = pd.pivot_table(data = df, columns = 'year', index = 'month', values = 'value', margins = True)
year   1949  1950  1951  1952  1953  1954  1955  1956  1957  1958  1959  1960  All
month                                                                             
Apr     370   472   485   574   463   487   543   101   301   395   479   591  438
Aug     120   230   260   287   230   341   530   359   450   437   114   238  299
Dec     314   443   352   545   120   485   519   501   561   509   426   262  419
Feb     535   558   513   444   545   266   191   459   143   351   351   443  399
Jan     202   430   591   335   274   428   439   149   317   314   316   108  325
Jul     288   251   376   575   419   113   363   205   369   336   256   162  309
Jun     171   459   543   269   343   415   527   153   583   307   140   571  373
Mar     448   187   393   148   150   373   466   487   261   289   287   228  309
May     206   199   291   158   154   188   554   489   545   312   592   235  326
Nov     566   357   121   289   234   152   180   290   555   379   444   491  338
Oct     221   408   413   370   406   445   305   576   370   152   164   180  334
Sep     202   249   559   563   584   364   134   409   403   466   400   598  410
All     303   353   408   379   326   338   395   348   404   353   330   342  357

The only difference is the names of the last column and the last row ('All' instead of 'month_mean' and 'year_mean').
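
As a quick check that the two routes agree, the sketch below builds both tables from the same long-format data and compares them. It uses the margins_name parameter of pandas.pivot_table to give the margins the same label as the manually added ones. The variable names (long_df, manual, auto), the shared label 'mean', and the random seed are assumptions made for this illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)  # arbitrary seed, only for reproducibility
long_df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
long_df['value'] = rng.integers(100, 600, len(long_df))
long_df['month'] = long_df['date'].dt.month_name().str.slice(stop=3)
long_df['year'] = long_df['date'].dt.year

# Manual route: pivot first, then append a mean column and a mean row.
manual = long_df.pivot(columns='year', index='month', values='value')
manual['mean'] = manual.mean(axis=1)
manual.loc['mean'] = manual.mean(axis=0)

# One-step route: margins=True appends the means, margins_name sets their label.
auto = pd.pivot_table(data=long_df, columns='year', index='month', values='value',
                      margins=True, margins_name='mean')

# Both tables hold the same numbers in the same layout.
same = np.allclose(manual.to_numpy(dtype=float), auto.to_numpy(dtype=float))
print(same)
```

The tables match here because every month and every year has exactly one value; with missing cells or duplicates the bottom-right corner could differ, since pivot_table computes it from the raw values.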
Now you are ready to draw the heatmap:

fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = df, annot = True, fmt = '.0f')

plt.show()

Complete Code

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))

# Three-letter month abbreviation; sorting here would have no effect because assignment aligns on the index
df['month'] = df['date'].dt.month_name().str.slice(stop = 3)
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')

df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)


fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = df, annot = True, fmt = '.0f')

plt.show()


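Optionally, you can use a different colormap for the last column and the last row to make the means stand out. The code below draws two heatmaps on the same axes: one in which the means are set to NaN, and one in which everything except the means is set to NaN: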

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np


df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))

# Three-letter month abbreviation; sorting here would have no effect because assignment aligns on the index
df['month'] = df['date'].dt.month_name().str.slice(stop = 3)
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')

df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)

df_values = df.copy()
df_values['month_mean'] = float('nan')
df_values.loc['year_mean'] = float('nan')

df_means = df.copy()
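# Blank everything except the last row and last column; position-based slicing requires iloc, not loc
df_means.iloc[:-1, :-1] = float('nan')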


fig, ax = plt.subplots()

sns.heatmap(ax = ax, data = df_values, annot = True, fmt = '.0f', cmap = 'Reds', vmin = df.to_numpy().min(), vmax = df.to_numpy().max())
sns.heatmap(ax = ax, data = df_means, annot = True, fmt = '.0f', cmap = 'Blues', vmin = df.to_numpy().min(), vmax = df.to_numpy().max())

plt.show()
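
The splitting step can also be written with a boolean mask instead of slice assignment. The following is a minimal sketch, assuming the means sit in the last row and the last column; the helper name split_values_and_means and the small example table are hypothetical and only serve to illustrate the idea.

```python
import numpy as np
import pandas as pd


def split_values_and_means(table):
    """Return (values_only, means_only) copies of a table whose last row
    and last column hold the means; all other cells become NaN."""
    is_mean = np.zeros(table.shape, dtype=bool)
    is_mean[-1, :] = True  # last row: column (year) means
    is_mean[:, -1] = True  # last column: row (month) means
    as_float = table.astype(float)
    values_only = as_float.mask(is_mean)   # NaN where is_mean is True
    means_only = as_float.where(is_mean)   # NaN where is_mean is False
    return values_only, means_only


# Tiny example table: 2 months x 2 years, plus the means.
table = pd.DataFrame([[1, 3], [5, 7]], index=['Jan', 'Feb'], columns=[1949, 1950])
table['month_mean'] = table.mean(axis=1)
table.loc['year_mean'] = table.mean(axis=0)

values_only, means_only = split_values_and_means(table)
print(values_only)
print(means_only)
```

Because seaborn.heatmap automatically masks cells with missing values, the two returned frames can be passed to the two sns.heatmap calls in place of df_values and df_means.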

Upvotes: 4
