Is it possible to add row and column statistics on the edges of a Seaborn heatmap?
For each row, I want to display the row mean (the mean for each month) on the right-hand side, and along the bottom edge I want to show the mean of each column (each year).
If you are working with a dataframe like this one (the values are random, so the numbers below will differ from yours):
df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))
date value
0 1949-01-01 202
1 1949-02-01 535
2 1949-03-01 448
3 1949-04-01 370
4 1949-05-01 206
.. ... ...
139 1960-08-01 238
140 1960-09-01 598
141 1960-10-01 180
142 1960-11-01 491
143 1960-12-01 262
You have to reshape it with pandas.DataFrame.pivot:
df['month'] = df['date'].dt.month_name().str.slice(stop = 3)
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')
year 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960
month
Apr 370 472 485 574 463 487 543 101 301 395 479 591
Aug 120 230 260 287 230 341 530 359 450 437 114 238
Dec 314 443 352 545 120 485 519 501 561 509 426 262
Feb 535 558 513 444 545 266 191 459 143 351 351 443
Jan 202 430 591 335 274 428 439 149 317 314 316 108
Jul 288 251 376 575 419 113 363 205 369 336 256 162
Jun 171 459 543 269 343 415 527 153 583 307 140 571
Mar 448 187 393 148 150 373 466 487 261 289 287 228
May 206 199 291 158 154 188 554 489 545 312 592 235
Nov 566 357 121 289 234 152 180 290 555 379 444 491
Oct 221 408 413 370 406 445 305 576 370 152 164 180
Sep 202 249 559 563 584 364 134 409 403 466 400 598
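Note that pivot sorts the month labels alphabetically (Apr, Aug, Dec, ...). If you would rather have the rows in calendar order, one option is to reindex the pivoted frame. A minimal sketch, assuming English month abbreviations (what dt.month_name() returns by default); the names wide and MONTH_ORDER are only illustrative:

```python
import numpy as np
import pandas as pd

# Same kind of sample data as above (random values, so yours will differ)
df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
df['value'] = np.random.randint(100, 600, len(df))
df['month'] = df['date'].dt.month_name().str.slice(stop=3)
df['year'] = df['date'].dt.year

# pivot() sorts the row labels, so the months come out alphabetically
wide = df.pivot(columns='year', index='month', values='value')

# Illustrative list of month abbreviations in calendar order
# (assumes English names, as returned by dt.month_name() by default)
MONTH_ORDER = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
               'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

# reindex() only reorders the rows; every label exists, so no NaN rows appear
wide = wide.reindex(MONTH_ORDER)
print(wide.index.tolist())
```

The mean column and row can then be added to the reordered frame exactly as described next.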
Then you can add a column with the mean of each month and a row with the mean of each year (the means are rounded in the output below):
df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)
year 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 month_mean
month
Apr 370 472 485 574 463 487 543 101 301 395 479 591 438
Aug 120 230 260 287 230 341 530 359 450 437 114 238 299
Dec 314 443 352 545 120 485 519 501 561 509 426 262 419
Feb 535 558 513 444 545 266 191 459 143 351 351 443 399
Jan 202 430 591 335 274 428 439 149 317 314 316 108 325
Jul 288 251 376 575 419 113 363 205 369 336 256 162 309
Jun 171 459 543 269 343 415 527 153 583 307 140 571 373
Mar 448 187 393 148 150 373 466 487 261 289 287 228 309
May 206 199 291 158 154 188 554 489 545 312 592 235 326
Nov 566 357 121 289 234 152 180 290 555 379 444 491 338
Oct 221 408 413 370 406 445 305 576 370 152 164 180 334
Sep 202 249 559 563 584 364 134 409 403 466 400 598 410
year_mean 303 353 408 379 326 338 395 348 404 353 330 342 357
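The bottom-right cell is the mean of the month_mean column. Since every month has exactly one value per year, the mean of the twelve row means equals the mean of all 144 values, so that corner is simply the overall mean. A quick check on random data (the name wide is illustrative, used to keep the long and wide frames apart):

```python
import numpy as np
import pandas as pd

# Random sample data in long format, as above
df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
df['value'] = np.random.randint(100, 600, len(df))
df['month'] = df['date'].dt.month_name().str.slice(stop=3)
df['year'] = df['date'].dt.year

# Wide table with margins, built the same way as in the answer
wide = df.pivot(columns='year', index='month', values='value')
wide['month_mean'] = wide.mean(axis=1)
wide.loc['year_mean'] = wide.mean(axis=0)

# Bottom-right cell: mean of the 12 month means
corner = wide.loc['year_mean', 'month_mean']
# Pooled mean of all 144 monthly values
grand_mean = df['value'].mean()

# Equal (up to floating-point rounding) because every row has 12 cells
print(np.isclose(corner, grand_mean))
```

If some cells were missing, a mean of row means would generally differ from the pooled mean, so this property relies on the grid being complete.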
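Alternatively, starting again from the original long-format dataframe (the one with the date, value, month and year columns, before the pivot), you can pivot and compute the means in one step with pandas.pivot_table, whose default aggfunc is the mean; margins = True adds the All row and column: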
df = pd.pivot_table(data = df, columns = 'year', index = 'month', values = 'value', margins = True)
year 1949 1950 1951 1952 1953 1954 1955 1956 1957 1958 1959 1960 All
month
Apr 370 472 485 574 463 487 543 101 301 395 479 591 438
Aug 120 230 260 287 230 341 530 359 450 437 114 238 299
Dec 314 443 352 545 120 485 519 501 561 509 426 262 419
Feb 535 558 513 444 545 266 191 459 143 351 351 443 399
Jan 202 430 591 335 274 428 439 149 317 314 316 108 325
Jul 288 251 376 575 419 113 363 205 369 336 256 162 309
Jun 171 459 543 269 343 415 527 153 583 307 140 571 373
Mar 448 187 393 148 150 373 466 487 261 289 287 228 309
May 206 199 291 158 154 188 554 489 545 312 592 235 326
Nov 566 357 121 289 234 152 180 290 555 379 444 491 338
Oct 221 408 413 370 406 445 305 576 370 152 164 180 334
Sep 202 249 559 563 584 364 134 409 403 466 400 598 410
All 303 353 408 379 326 338 395 348 404 353 330 342 357
The only difference is the name of the last column and the last row: All instead of month_mean and year_mean.
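To check that the two approaches agree, you can build both from the same long-format data, rename the All labels, and compare the frames. A sketch (long_df, manual and table are illustrative names); the values are cast to float because pivot_table may return integer dtype for the inner cells when the input column is integer:

```python
import numpy as np
import pandas as pd

long_df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
long_df['value'] = np.random.randint(100, 600, len(long_df))
long_df['month'] = long_df['date'].dt.month_name().str.slice(stop=3)
long_df['year'] = long_df['date'].dt.year

# Manual approach: pivot, then append the means
manual = long_df.pivot(columns='year', index='month', values='value')
manual['month_mean'] = manual.mean(axis=1)
manual.loc['year_mean'] = manual.mean(axis=0)

# pivot_table approach: aggfunc defaults to the mean, margins add "All"
table = pd.pivot_table(data=long_df, columns='year', index='month',
                       values='value', margins=True)

# Rename the margin labels so both frames use the same names
table = table.rename(index={'All': 'year_mean'}, columns={'All': 'month_mean'})

# Raises AssertionError if labels or values differ (floats compared with a tolerance)
pd.testing.assert_frame_equal(manual.astype(float), table.astype(float),
                              check_names=False, check_index_type=False,
                              check_column_type=False)
```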
Now you are ready to draw the heatmap:
fig, ax = plt.subplots()
sns.heatmap(ax = ax, data = df, annot = True, fmt = '.0f')
plt.show()
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# monthly sample data, 1949 to 1960 (random values)
df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))
# reshape to one row per month and one column per year
df['month'] = df['date'].dt.month_name().str.slice(stop = 3)
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')
# append the margins: month means as the last column, year means as the last row
df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)
fig, ax = plt.subplots()
sns.heatmap(ax = ax, data = df, annot = True, fmt = '.0f')
plt.show()
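To make the margins read as separate from the grid, you could also draw a line between the monthly values and the means. In a seaborn heatmap each cell occupies one unit in data coordinates, so the boundary above the last row is at y = number of rows - 1 and the boundary before the last column is at x = number of columns - 1. A sketch using matplotlib's axhline and axvline (the white color and line width are arbitrary choices):

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Build the same table with margins as in the script above (random sample data)
df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
df['value'] = np.random.randint(100, 600, len(df))
df['month'] = df['date'].dt.month_name().str.slice(stop=3)
df['year'] = df['date'].dt.year
df = df.pivot(columns='year', index='month', values='value')
df['month_mean'] = df.mean(axis=1)
df.loc['year_mean'] = df.mean(axis=0)

fig, ax = plt.subplots()
sns.heatmap(ax=ax, data=df, annot=True, fmt='.0f')

# Cell (i, j) covers x in [j, j + 1] and y in [i, i + 1], so the margin
# row starts at y = n_rows - 1 and the margin column at x = n_cols - 1
n_rows, n_cols = df.shape
ax.axhline(n_rows - 1, color='white', linewidth=3)  # separates the year_mean row
ax.axvline(n_cols - 1, color='white', linewidth=3)  # separates the month_mean column
```

Call plt.show() afterwards to display the figure.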
Optionally, you can draw the last column and the last row with a different colormap, so the means stand out from the monthly values:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
df = pd.DataFrame({'date': pd.date_range(start = '1949-01-01', end = '1960-12-01', freq = 'MS')})
df['value'] = np.random.randint(100, 600, len(df))
df['month'] = df['date'].dt.month_name().str.slice(stop = 3)
df['year'] = df['date'].dt.year
df = df.pivot(columns = 'year', index = 'month', values = 'value')
df['month_mean'] = df.mean(axis = 1)
df.loc['year_mean'] = df.mean(axis = 0)
# copy with only the monthly values (margins blanked out)
df_values = df.copy()
df_values['month_mean'] = float('nan')
df_values.loc['year_mean'] = float('nan')
# copy with only the margins; iloc because the slices are positional, not labels
df_means = df.copy()
df_means.iloc[:-1, :-1] = float('nan')
# seaborn leaves NaN cells empty, so the two heatmaps fill complementary cells;
# a shared vmin/vmax keeps their color scales comparable
vmin = df.to_numpy().min()
vmax = df.to_numpy().max()
fig, ax = plt.subplots()
sns.heatmap(ax = ax, data = df_values, annot = True, fmt = '.0f', cmap = 'Reds', vmin = vmin, vmax = vmax)
sns.heatmap(ax = ax, data = df_means, annot = True, fmt = '.0f', cmap = 'Blues', vmin = vmin, vmax = vmax)
plt.show()
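Instead of building two NaN-filled copies, the same two-colormap effect can be obtained with the mask parameter of sns.heatmap, which hides the cells where the mask is True. A sketch, with the boolean array margin as an illustrative name:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Table with margins, built as in the scripts above (random sample data)
df = pd.DataFrame({'date': pd.date_range(start='1949-01-01', end='1960-12-01', freq='MS')})
df['value'] = np.random.randint(100, 600, len(df))
df['month'] = df['date'].dt.month_name().str.slice(stop=3)
df['year'] = df['date'].dt.year
df = df.pivot(columns='year', index='month', values='value')
df['month_mean'] = df.mean(axis=1)
df.loc['year_mean'] = df.mean(axis=0)

# Boolean mask that is True on the margin cells (last row and last column)
margin = np.zeros(df.shape, dtype=bool)
margin[-1, :] = True
margin[:, -1] = True

# Shared color range so both colormaps are scaled the same way
vmin, vmax = df.to_numpy().min(), df.to_numpy().max()

fig, ax = plt.subplots()
# Cells where the mask is True are hidden, so each call draws only its own region
sns.heatmap(ax=ax, data=df, mask=margin, annot=True, fmt='.0f',
            cmap='Reds', vmin=vmin, vmax=vmax)
sns.heatmap(ax=ax, data=df, mask=~margin, annot=True, fmt='.0f',
            cmap='Blues', vmin=vmin, vmax=vmax)
```

Because the two masks are exact complements, every cell is drawn by exactly one of the two calls; add plt.show() to display the result.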