AndreyIto

Reputation: 974

Python heatmap: distorting colour mapping

I have a set of values like this (just imagine that these are widget sales):

year | Sarah | Elizabeth | Jones | Robert |
-------------------------------------------
2003 | 11    | 0         |  0    | 0      |
2004 | 16    | 0         |  0    | 6      |
2005 | 12    | 0         |  4    | 11     |
2006 | 33    | 0         |  0    | 3      |
2007 | 18    | 0         |  0    | 0      |
2008 | 18    | 0         |  0    | 0      |
2009 | 110   | 0         |  0    | 0      |
2010 | 83    | 0         |  0    | 0      |
2011 | 1553  | 20        |  25   | 0      |
2012 | 785   | 27        |  0    | 186    |
2013 | 561   | 73        |  0    | 3      |

I made a heatmap with seaborn and matplotlib, but unfortunately the large values dominate the colour scale, so the plot looks like one very dark square amongst many very pale ones.
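
To put a number on the problem, the sketch below assumes the heatmap uses a plain linear scaling from the data minimum to the maximum (a matplotlib Normalize, which is what seaborn applies when no vmin, vmax or norm is given) and computes where each cell of the table lands on that scale.

```python
import numpy as np
from matplotlib.colors import Normalize

# Sales from the table above: rows are 2003..2013,
# columns are Sarah, Elizabeth, Jones, Robert
sales = np.array([
    [11, 0, 0, 0],
    [16, 0, 0, 6],
    [12, 0, 4, 11],
    [33, 0, 0, 3],
    [18, 0, 0, 0],
    [18, 0, 0, 0],
    [110, 0, 0, 0],
    [83, 0, 0, 0],
    [1553, 20, 25, 0],
    [785, 27, 0, 186],
    [561, 73, 0, 3],
])

# Assumed default: linear scaling, 0 -> 0.0 and the maximum (1553) -> 1.0
norm = Normalize(vmin=sales.min(), vmax=sales.max())
fractions = np.asarray(norm(sales))

# Share of cells that end up in the bottom 15% of the colormap
low_share = np.mean(fractions < 0.15)
print(f"{low_share:.0%} of cells use the bottom 15% of the colormap")
# prints: 93% of cells use the bottom 15% of the colormap
```

Only three cells (Sarah's 1553, 785 and 561) rise above the bottom 15% of the scale; the other 41 of 44 cells are squeezed into it, which is why they all look alike.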

Is there a way to do the colour mapping with a piece-wise function, e.g. map the combined range [0, 200) and (550, 1600] onto a single linear, unbroken range of colour values? Unfortunately, all the colormaps I've seen so far are preset.
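
The piece-wise mapping described above can also be sketched as a continuous custom normalization, as an alternative to the discrete BoundaryNorm approach in the answer. This is not from the original thread: PiecewiseLinearNorm is a hypothetical helper that subclasses matplotlib.colors.Normalize and uses np.interp between breakpoints. The breakpoints [0, 200, 550, 1600] and colour positions [0.0, 0.15, 0.17, 1.0] are assumptions, chosen so that [0, 200] takes about 15% of the colormap, [550, 1600] about 83%, and the empty gap only a thin 2% band.

```python
import numpy as np
from matplotlib.colors import Normalize


class PiecewiseLinearNorm(Normalize):
    """Hypothetical helper: piecewise-linear map from data values to [0, 1]."""

    def __init__(self, data_points, colour_points, clip=False):
        self.data_points = np.asarray(data_points, dtype=float)
        self.colour_points = np.asarray(colour_points, dtype=float)
        # np.interp needs increasing x-coordinates, in both directions
        if (np.any(np.diff(self.data_points) <= 0)
                or np.any(np.diff(self.colour_points) <= 0)):
            raise ValueError("breakpoints must be strictly increasing")
        super().__init__(vmin=self.data_points[0],
                         vmax=self.data_points[-1], clip=clip)

    def __call__(self, value, clip=None):
        # process_value returns a float masked array plus a scalar flag;
        # values outside the breakpoints are clamped to the end colours
        result, is_scalar = self.process_value(value)
        mapped = np.interp(result.filled(np.nan),
                           self.data_points, self.colour_points)
        out = np.ma.masked_array(mapped, mask=np.ma.getmaskarray(result))
        return out[0] if is_scalar else out

    def inverse(self, value):
        # Map colour positions back to data values (used for colorbar ticks)
        return np.interp(value, self.colour_points, self.data_points)


# Assumed breakpoints: [0, 200] -> [0, 0.15], the empty 200-550 gap -> a thin
# band [0.15, 0.17], and [550, 1600] -> [0.17, 1.0]
norm = PiecewiseLinearNorm([0, 200, 550, 1600], [0.0, 0.15, 0.17, 1.0])
print(norm(np.array([0, 100, 200, 550, 1075])))
```

It can be passed to sns.heatmap(df, norm=norm) the same way as the BoundaryNorm in the answer. Keeping both breakpoint lists strictly increasing is a deliberate choice: it keeps inverse() well defined, and recent matplotlib colorbars use a norm's inverse to place their ticks.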

Upvotes: 1

Views: 459

Answers (1)

Y. Luo

Reputation: 5722

It seems that you are looking for the discrete-bounds method of colormap normalization (matplotlib.colors.BoundaryNorm). Here is an example using your data:

from io import StringIO  # Python 3; Python 2 used the separate StringIO module

import matplotlib.pyplot as plt
import matplotlib.colors as colors
import numpy as np
import pandas as pd
import seaborn as sns

s = """year Sarah Elizabeth Jones Robert
2003 11 0 0 0
2004 16 0 0 6
2005 12 0 4 11
2006 33 0 0 3
2007 18 0 0 0
2008 18 0 0 0
2009 110 0 0 0
2010 83 0 0 0
2011 1553 20 25 0
2012 785 27 0 186
2013 561 73 0 3"""
# Space-separated table; the first column ("year") becomes the index
df = pd.read_csv(StringIO(s), sep=' ', header=0, index_col=0)

fig = plt.figure(figsize=(16, 6))
sns.set(font_scale=2) 

ax1 = plt.subplot(121)
sns.heatmap(df, ax=ax1)
ylabels = ax1.get_yticklabels()
plt.setp(ylabels, rotation=0)

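ax2 = plt.subplot(122)
# Boundaries form two linear segments, [0, 200] and [550, 1600], so each segment
# gets about half of the colormap. Keep len(bounds) - 1 (the number of bins) at
# or below ncolors: newer matplotlib versions raise a ValueError otherwise,
# which 1050 points per segment (2099 bins) would trigger.
bounds = np.concatenate((np.linspace(0, 200, 101), np.linspace(550, 1600, 101)))
norm = colors.BoundaryNorm(boundaries=bounds, ncolors=256)
sns.heatmap(df, ax=ax2, norm=norm)
ylabels = ax2.get_yticklabels()
plt.setp(ylabels, rotation=0)
# Label the colorbar at round values in each segment
cbar = ax2.collections[0].colorbar
cbar_ticks = np.concatenate((np.arange(0, 201, 50), np.arange(700, 1601, 200)))
cbar.set_ticks(cbar_ticks)
cbar.set_ticklabels([str(t) for t in cbar_ticks])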

fig.tight_layout()
plt.show()

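
To see how BoundaryNorm assigns colours without drawing anything, the sketch below rebuilds the boundaries from the code above (101 points per segment, an assumption that keeps the bin count within ncolors=256) and prints the colormap index that a few values from the table receive.

```python
import numpy as np
from matplotlib.colors import BoundaryNorm

# Two linear segments of boundaries: [0, 200] and [550, 1600]
bounds = np.concatenate((np.linspace(0, 200, 101), np.linspace(550, 1600, 101)))
norm = BoundaryNorm(boundaries=bounds, ncolors=256)

# BoundaryNorm returns an integer index into the colormap's 256 colours
for value in [0, 11, 186, 561, 785, 1553]:
    print(value, int(norm(value)))
```

The indices grow with the data but are spaced by bin position rather than by value, which is what stretches the small numbers across half of the colormap. Every value in the empty 200 to 550 gap falls into the single bin between the two segments, so all such values would share one colour.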

Upvotes: 2
