Dark Apostle

Reputation: 197

set custom tick labels on heatmap color bar

I have a list of dataframes named merged_dfs. I loop through it to compute each correlation matrix and plot it as a seaborn heatmap in its own subplot.

I want to customize the colorbar tick labels, but I am having trouble figuring out how to do it with my example.

Currently, my colorbar scale values from top to bottom are

[1,0.5,0,-0.5,-1]

I want to keep these values, but change the tick labels to be

[1,0.5,0,0.5,1] 

for my diverging color bar.
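Put differently, each label should be the absolute value of its tick, while the tick positions themselves stay signed. A tiny sketch of that mapping (the `:g` format spec is just one way to drop trailing zeros):

```python
# Tick values from top to bottom, as listed above
tick_values = [1, 0.5, 0, -0.5, -1]
# Each label is the absolute value of its tick, formatted without trailing zeros
tick_labels = [f"{abs(t):g}" for t in tick_values]
print(tick_labels)  # ['1', '0.5', '0', '0.5', '1']
```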

Here is the code and my attempt:

fig, ax = plt.subplots(nrows=6, ncols=2, figsize=(20, 20))
for i, (title, merging) in enumerate(zip(new_name_data, merged_dfs)):
    graph = merging.corr()
    colormap = sns.diverging_palette(250, 250, as_cmap=True)
    a = sns.heatmap(graph.abs(), cmap=colormap, vmin=-1, vmax=1, center=0, annot=graph, ax=ax.flat[i])
    cbar = fig.colorbar(a)
    cbar.set_ticklabels(["1", "0.5", "0", "0.5", "1"])
fig.delaxes(ax[5, 1])

plt.show()
plt.close()

I keep getting this error:

AttributeError: 'AxesSubplot' object has no attribute 'get_array'

Upvotes: 1

Views: 6306

Answers (1)

JohanC

Reputation: 80329

Several things are going wrong:

  • fig.colorbar(...) would create a new colorbar, by default appended to the last subplot that was created.
  • sns.heatmap returns the ax (i.e. the subplot) it drew on. This is very different from matplotlib plotting functions such as plt.imshow(), which return the graphical element that was plotted. fig.colorbar(a) expects such an element (a "mappable"), which is why passing the ax fails with 'AxesSubplot' object has no attribute 'get_array'.
  • You can suppress the heatmap's colorbar (cbar=False), and then create a new one with the parameters you want.
  • fig.colorbar(...) needs an ax=... parameter, telling it which subplot to attach to, when the figure contains more than one subplot.
  • Instead of creating a new colorbar, you can pass the colorbar parameters to sns.heatmap via cbar_kws=.... The colorbar itself can then be found via ax.collections[0].colorbar. (ax.collections[0] is where matplotlib stored the graphical object that contains the heatmap.)
  • Looping over an index is generally discouraged in Python. It's usually more readable, easier to maintain and less error-prone to put everything, including the axes, into the zip call.
  • As your vmin is now -1, taking the absolute value for the coloring seems to be a mistake: graph.abs() only ranges from 0 to 1, so the negative half of the colormap would never be used.
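The second and fifth points can be checked directly. The following is a minimal check using random stand-in data (the DataFrame `df` is hypothetical). It confirms that sns.heatmap returns an Axes rather than a mappable, and that the heatmap's QuadMesh and seaborn's colorbar are reachable through ax.collections[0]:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.axes import Axes
from matplotlib.collections import QuadMesh

# Hypothetical stand-in data: any numeric DataFrame works here
df = pd.DataFrame(np.random.rand(5, 4), columns=list("ABCD"))

fig, ax = plt.subplots()
returned = sns.heatmap(df.corr(), vmin=-1, vmax=1, center=0, ax=ax)

# sns.heatmap hands back the Axes it drew on, not the plotted artist
print(isinstance(returned, Axes))      # True
print(hasattr(returned, "get_array"))  # False, hence the AttributeError

# The heatmap itself is a QuadMesh stored in the Axes' collections...
mesh = returned.collections[0]
print(isinstance(mesh, QuadMesh))      # True
# ...and seaborn's own colorbar is attached to that mesh
print(mesh.colorbar is not None)       # True
```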
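import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

# Random stand-in data: 5 DataFrames with 7 columns each
merged_dfs = [pd.DataFrame(data=np.random.rand(5, 7), columns=[*'ABCDEFG']) for _ in range(5)]
new_name_data = [f'Dataset {i + 1}' for i in range(len(merged_dfs))]

ticks = [-1, -0.5, 0, 0.5, 1]  # colorbar tick positions (the values stay signed)
colormap = sns.diverging_palette(250, 250, as_cmap=True)

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 7))
for title, merging, ax in zip(new_name_data, merged_dfs, axes.flat):
    graph = merging.corr()
    # cbar_kws is forwarded to the colorbar that seaborn creates
    sns.heatmap(graph, cmap=colormap, vmin=-1, vmax=1, center=0, annot=True, ax=ax,
                cbar_kws={'ticks': ticks})
    ax.set_title(title)
    # ax.collections[0] is the heatmap's QuadMesh; label its colorbar with |tick|
    ax.collections[0].colorbar.set_ticklabels([f'{abs(t):g}' for t in ticks])

fig.delaxes(axes.flat[-1])  # 6 subplots but only 5 datasets: drop the unused one
fig.tight_layout()
plt.show()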

sns.heatmap with changed colorbar
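The alternative from the bullet points, suppressing seaborn's colorbar with cbar=False and creating one per subplot with fig.colorbar(..., ax=ax), might look like the sketch below. It uses random stand-in data. Instead of fixed tick labels it swaps in a matplotlib FuncFormatter that prints the absolute value of each tick, so the labels stay correct even if the tick positions change; this formatter is an assumption of the sketch, not part of the answer above.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.ticker import FuncFormatter

# Random stand-in data, as in the answer above
merged_dfs = [pd.DataFrame(np.random.rand(5, 7), columns=[*"ABCDEFG"]) for _ in range(5)]
new_name_data = [f"Dataset {i + 1}" for i in range(len(merged_dfs))]

ticks = [-1, -0.5, 0, 0.5, 1]
# Formatter that prints |value|, so a tick at -0.5 is labelled "0.5"
abs_formatter = FuncFormatter(lambda x, pos: f"{abs(x):g}")
colormap = sns.diverging_palette(250, 250, as_cmap=True)

fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(12, 7))
for title, merging, ax in zip(new_name_data, merged_dfs, axes.flat):
    # cbar=False: seaborn draws the heatmap without its own colorbar
    sns.heatmap(merging.corr(), cmap=colormap, vmin=-1, vmax=1, center=0,
                annot=True, ax=ax, cbar=False)
    ax.set_title(title)
    mesh = ax.collections[0]  # the QuadMesh drawn by the heatmap
    # ax=ax tells the figure which subplot the colorbar belongs to
    fig.colorbar(mesh, ax=ax, ticks=ticks, format=abs_formatter)

fig.delaxes(axes.flat[-1])  # only 5 datasets for 6 subplots
plt.show()
```

Because the formatter works on whatever tick values the locator produces, changing `ticks` (or dropping it) would still give absolute-value labels.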

Upvotes: 3
