Tasha

Reputation: 55

How to use the same colorbar for seaborn heatmaps?

I have three heatmap subplots, and each one gets a slightly different colorbar when plotted individually.

I plotted the heatmaps as follows:

import matplotlib.pyplot as plt
import seaborn as sns

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 10), sharey=True)

sns.heatmap(df, cmap=colormap, ax=axes[0])
sns.heatmap(df2, cmap=colormap, ax=axes[1])
sns.heatmap(df3, cmap=colormap, ax=axes[2])

I know that I can just set cbar=False for the first two plots; however, because the colorbars differ slightly, the third colorbar will not represent all subplots. My first subplot ranges from 0 to 35, the second from 0 to 36, and the third from 0 to 37. I want a colorbar that spans the 0-37 range, but if I just set cbar=False, the colors in the first two subplots will correspond to the wrong values.
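To see why cbar=False alone is not enough, here is a minimal sketch using small synthetic DataFrames (assumed stand-ins for df, df2 and df3, since the real data is not shown). Without vmin/vmax, seaborn infers the color range from each heatmap's own data, so the top color of the colormap stands for 35 in the first panel but 37 in the third.

```python
# Minimal sketch: why separately drawn heatmaps end up with different color scales.
# The frames below are synthetic (assumed) stand-ins for the question's df, df2, df3.
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Value ranges 0-35, 0-36 and 0-37, matching the ranges described in the question.
frames = [pd.DataFrame([[0, 12], [24, top]]) for top in (35, 36, 37)]

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 4), sharey=True)
for frame, ax in zip(frames, axes):
    sns.heatmap(frame, ax=ax)  # no vmin/vmax: range is inferred per plot

# Each heatmap's QuadMesh has its own normalization, so the upper limit differs.
lower_limits = [ax.collections[0].norm.vmin for ax in axes]
upper_limits = [ax.collections[0].norm.vmax for ax in axes]
print(lower_limits, upper_limits)
```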

How would I set up my subplots so that there is just one colorbar that applies to all three, instead of one per subplot? Sorry about the lack of figures; I am unable to share them at this stage.

Upvotes: 1

Views: 3909

Answers (1)

tdy

Reputation: 41477

I want a colorbar that encompasses the 0-37 range but obviously the colors will correspond to the wrong values for the first two subplots if I just do cbar=False.

In addition to setting cbar=False on the first two heatmaps, anchor all the heatmaps to the same colormap range:

  • Either set vmin and vmax on each heatmap:

    sns.heatmap(df,  ax=axes[0], vmin=0, vmax=37, cbar=False)
    sns.heatmap(df2, ax=axes[1], vmin=0, vmax=37, cbar=False)
    sns.heatmap(df3, ax=axes[2], vmin=0, vmax=37)
    #                            ^       ^
    
  • Or create a min/max norm using matplotlib.colors.Normalize:

    import matplotlib.colors as mcolors
    norm = mcolors.Normalize(0, 37)
    
    sns.heatmap(df,  ax=axes[0], norm=norm, cbar=False)
    sns.heatmap(df2, ax=axes[1], norm=norm, cbar=False)
    sns.heatmap(df3, ax=axes[2], norm=norm)
    #                            ^
    

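A self-contained sketch of the vmin/vmax option, using synthetic data and a placeholder colormap ("viridis" is an assumption; the question's colormap is not shown). All three panels end up with the same 0-37 normalization, and the figure gets only one extra axes for the single colorbar.

```python
# Sketch of the shared vmin/vmax approach with synthetic (assumed) data.
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

df, df2, df3 = (pd.DataFrame([[0, 12], [24, top]]) for top in (35, 36, 37))
colormap = "viridis"  # assumed placeholder for the question's colormap

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 10), sharey=True)
# Same vmin/vmax everywhere; only the last heatmap draws a colorbar.
sns.heatmap(df, cmap=colormap, ax=axes[0], vmin=0, vmax=37, cbar=False)
sns.heatmap(df2, cmap=colormap, ax=axes[1], vmin=0, vmax=37, cbar=False)
sns.heatmap(df3, cmap=colormap, ax=axes[2], vmin=0, vmax=37)

# Every panel now maps the same value to the same color.
limits = [(ax.collections[0].norm.vmin, ax.collections[0].norm.vmax) for ax in axes]
print(limits)
```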
If you want to compute vmin and vmax automatically instead of hard-coding them, flatten and stack all the data, then take its overall min and max:

import numpy as np
import matplotlib.colors as mcolors
values = np.hstack([d.values.ravel() for d in [df, df2, df3]])
norm = mcolors.Normalize(values.min(), values.max())
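A sketch of the automatic variant end to end, again with synthetic DataFrames (assumed). It swaps in np.nanmin/np.nanmax for .min()/.max(), because heatmap data often contains NaN cells and a plain .min() on an array with NaN returns nan.

```python
# Sketch: derive one shared range from all the data and pass it via norm.
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the script runs headless
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic (assumed) stand-ins; df2 has one NaN cell to exercise the nan-aware reduction.
df = pd.DataFrame([[0.0, 12.0], [24.0, 35.0]])
df2 = pd.DataFrame([[1.0, np.nan], [20.0, 36.0]])
df3 = pd.DataFrame([[2.0, 15.0], [30.0, 37.0]])

values = np.hstack([d.values.ravel() for d in [df, df2, df3]])
# np.nanmin/np.nanmax skip missing cells instead of returning nan.
norm = mcolors.Normalize(np.nanmin(values), np.nanmax(values))

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 10), sharey=True)
for i, (d, ax) in enumerate(zip([df, df2, df3], axes)):
    # All heatmaps share one norm; only the last one draws a colorbar.
    sns.heatmap(d, ax=ax, norm=norm, cbar=(i == 2))
```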

Upvotes: 3
