Rodrigo

Plotting multiple seaborn heatmaps with individual color bars

Is it possible to plot multiple seaborn heatmaps in a single figure, with shared y tick labels and an individual color bar for each heatmap, like the figure below?

[Image: desired figure]

So far I can only plot the heatmaps individually, using the following code:

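import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Figure 1 (df holds the pathway-completeness data)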

plt.figure()
sns.set()
comp = sns.heatmap(df, cmap="coolwarm", linewidths=.5, xticklabels=True, yticklabels=True, cbar_kws={"orientation": "horizontal", "label": "Pathway completeness", "pad": 0.004})
comp.set_xticklabels(comp.get_xticklabels(), rotation=-90)
comp.xaxis.tick_top() # x axis on top
comp.xaxis.set_label_position('top')
cbar = comp.collections[0].colorbar
cbar.set_ticks([0, 50, 100])
cbar.set_ticklabels(['0%', '50%', '100%'])
figure = comp.get_figure()
figure.savefig("hetmap16.png", format='png', bbox_inches='tight')

# Figure 2 (Figure 3 is the same, but with a different database)

plt.figure()
sns.set()
df = pd.DataFrame(heatMapFvaMinDictP)
fvaMax = sns.heatmap(df, cmap="rocket_r", linewidths=.5, xticklabels=True, cbar_kws={"orientation": "horizontal", "label": "Minimum average flux", "pad": 0.004})
fvaMax.set_xticklabels(fvaMax.get_xticklabels(), rotation=-90)
fvaMax.xaxis.tick_top() # x axis on top
fvaMax.xaxis.set_label_position('top')
fvaMax.tick_params(axis='y', labelleft=False)
figure = fvaMax.get_figure()
figure.savefig("fva1.png", format='png', bbox_inches='tight')

Answers (2)

JohanC

Seaborn builds upon matplotlib, which can be used for further customizing plots. plt.subplots(ncols=3, sharey=True, ...) creates three subplots with a shared y-axis. Adding ax=ax1 to sns.heatmap(..., ax=...) creates the heatmap on the desired subplot. Note that the return value of sns.heatmap is again that same ax.
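A minimal sketch of just the subplot mechanics described above, using toy data (the row names and the Agg backend are assumptions for a headless run, not part of the answer): the heatmap is drawn on the Axes passed via ax=, sns.heatmap hands that same Axes back, and ax1 and ax2 report a shared y-axis.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Toy data: the same row labels would be used for every panel
rows = [f"row{i}" for i in range(5)]
df = pd.DataFrame(np.random.rand(5, 2), index=rows, columns=["A", "B"])

# Three side-by-side axes that share one y-axis
fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, sharey=True)

# ax= tells seaborn which subplot to draw on; it returns that same Axes
returned = sns.heatmap(df, ax=ax1)
print(returned is ax1)

# sharey=True joins the y-axes of all three subplots
print(ax1.get_shared_y_axes().joined(ax1, ax2))
```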

The following code shows an example. vmin and vmax are explicitly set for the first heatmap to make sure that both values will appear in the colorbar (the default colorbar runs between the minimum and maximum of the encountered values).
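To see the effect of vmin and vmax, the sketch below (toy values chosen so that neither 0 nor 100 occurs in the data) reads the color limits back from the heatmap's mesh with get_clim(). Without explicit limits the range only spans 12 to 88, so colorbar ticks at 0 and 100 would fall outside it.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for a headless run
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Values that never reach 0 or 100
df = pd.DataFrame({"column 1": [12.0, 47.0], "column 2": [63.0, 88.0]})

fig, (left, right) = plt.subplots(ncols=2)

# Default: the color range (and therefore the colorbar) spans the observed min and max
sns.heatmap(df, ax=left)
auto_clim = left.collections[0].get_clim()

# Explicit vmin/vmax pin the range, so 0 and 100 exist on the colorbar
sns.heatmap(df, ax=right, vmin=0, vmax=100)
fixed_clim = right.collections[0].get_clim()

print(auto_clim)   # spans 12 to 88
print(fixed_clim)  # spans 0 to 100
```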

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

sns.set()
fig, (ax1, ax2, ax3) = plt.subplots(ncols=3, sharey=True, figsize=(20, 8))

N = 20
labels = [''.join(np.random.choice(list('abcdefghi '), 40)) for _ in range(N)]
df = pd.DataFrame({'column 1': np.random.uniform(0, 100, N), 'column 2': np.random.uniform(0, 100, N)},
                  index=labels)
sns.heatmap(df, cmap="coolwarm", linewidths=.5, xticklabels=True, yticklabels=True, ax=ax1, vmin=0, vmax=100,
            cbar_kws={"orientation": "horizontal", "label": "Pathway completeness", "pad": 0.004})
ax1.set_xticklabels(ax1.get_xticklabels(), rotation=-90)
ax1.xaxis.tick_top()  # x axis on top
ax1.xaxis.set_label_position('top')
cbar = ax1.collections[0].colorbar
cbar.set_ticks([0, 50, 100])
cbar.set_ticklabels(['0%', '50%', '100%'])

for ax in (ax2, ax3):
    max_value = 10 if ax == ax2 else 1000
    df = pd.DataFrame({'column 1': np.random.uniform(0, max_value, N), 'column 2': np.random.uniform(0, max_value, N)},
                      index=labels)
    sns.heatmap(df, cmap="rocket_r", linewidths=.5, xticklabels=True, ax=ax,
                cbar_kws={"orientation": "horizontal", "pad": 0.004,
                          "label": "Minimum average flux"})
    ax.set_xticklabels(ax.get_xticklabels(), rotation=-90)
    ax.xaxis.tick_top()  # x axis on top
    ax.xaxis.set_label_position('top')

plt.tight_layout()
fig.savefig("subplots.png", format='png', bbox_inches='tight')
plt.show()

[Image: example plot]
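If more panels are needed, the per-axis settings from the answer could be factored into a small function. draw_panel and random_frame below are hypothetical helpers written for this sketch and the data is random; the heatmap and colorbar options are the ones used in the answer.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for a headless run
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns


def draw_panel(ax, df, cmap, label, **heatmap_kws):
    """Hypothetical helper: one heatmap with a horizontal colorbar and x labels on top."""
    sns.heatmap(df, cmap=cmap, linewidths=.5, xticklabels=True, ax=ax,
                cbar_kws={"orientation": "horizontal", "label": label, "pad": 0.004},
                **heatmap_kws)
    ax.set_xticklabels(ax.get_xticklabels(), rotation=-90)
    ax.xaxis.tick_top()  # x axis on top
    ax.xaxis.set_label_position('top')
    return ax.collections[0].colorbar  # the Colorbar seaborn attached to the mesh


rng = np.random.default_rng(0)
labels = [f"pathway {i}" for i in range(10)]


def random_frame(max_value):
    """Hypothetical stand-in for the real data: 10 rows x 2 columns of random values."""
    return pd.DataFrame(rng.uniform(0, max_value, size=(10, 2)),
                        index=labels, columns=["column 1", "column 2"])


# (data, colormap, colorbar label, extra heatmap options) for each panel
panels = [
    (random_frame(100), "coolwarm", "Pathway completeness",
     {"vmin": 0, "vmax": 100, "yticklabels": True}),
    (random_frame(10), "rocket_r", "Minimum average flux", {}),
    (random_frame(1000), "rocket_r", "Minimum average flux", {}),
]

sns.set()
fig, axes = plt.subplots(ncols=len(panels), sharey=True, figsize=(12, 6))
cbars = [draw_panel(ax, df, cmap, label, **kws)
         for ax, (df, cmap, label, kws) in zip(axes, panels)]

# Percentage ticks only make sense on the first (0-100) panel
cbars[0].set_ticks([0, 50, 100])
cbars[0].set_ticklabels(['0%', '50%', '100%'])
```

Keeping the extra options in a per-panel dict means only the first panel gets the fixed 0 to 100 range and the y tick labels, while the others use their own data range.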

StupidWolf

You can concatenate the two data frames and use FacetGrid with FacetGrid.map_dataframe; you might need to adjust the aesthetics a bit. I don't have your data, so here is an example with random data:

import pandas as pd
import numpy as np
import seaborn as sns

np.random.seed(111)
df1 = pd.DataFrame({'A':np.random.randn(15),'B':np.random.randn(15)},
                   index=['row_variable'+str(i+1) for i in range(15)])

df2 = pd.DataFrame({'A':np.random.randn(15),'B':np.random.randn(15)},
                   index=['row_variable'+str(i+1) for i in range(15)])

We annotate the data frames with a column indicating the database, as in your case, and also set up a dictionary with the color scheme for each data frame:

df1['database'] = "database1"
df2['database'] = "database2"

dat = pd.concat([df1,df2])
cdict = {'database1':'rocket_r','database2':'coolwarm'}
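A sketch of the same concatenation using DataFrame.assign, which tags copies of the frames instead of adding the column to df1 and df2 in place (a swapped-in variant, not the answer's code). It also shows that the combined frame repeats every row label, so a label lookup such as dat.loc['row_variable1'] returns two rows, one per database.

```python
import numpy as np
import pandas as pd

np.random.seed(111)
index = ['row_variable' + str(i + 1) for i in range(15)]
df1 = pd.DataFrame({'A': np.random.randn(15), 'B': np.random.randn(15)}, index=index)
df2 = pd.DataFrame({'A': np.random.randn(15), 'B': np.random.randn(15)}, index=index)

# assign() returns tagged copies; df1 and df2 keep only columns A and B
dat = pd.concat([df1.assign(database="database1"),
                 df2.assign(database="database2")])

# Each facet will later see the rows of one database; groupby shows the same split
sizes = dat.groupby("database").size().to_dict()
print(sizes)

# Row labels are repeated across the two blocks
print(dat.loc['row_variable1'])
```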

And define a function to draw the heatmap:

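def heat(data, color):
    # data is one facet's subset, so every row comes from the same database;
    # .iloc[0] is positional (plain [0] on a string index is deprecated in pandas 2.x)
    sns.heatmap(data[['A', 'B']], cmap=cdict[data['database'].iloc[0]],
                cbar_kws={"orientation": "horizontal"})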
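A minimal sketch of how FacetGrid.map_dataframe drives a function like heat: it calls the function once per facet with that facet's rows as the data keyword argument (plus a color keyword), and it makes the facet's Axes current first, which is why sns.heatmap can be called without ax=. The record function and the tiny data frame are hypothetical, used only to log each call.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend for a headless run
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

dat = pd.DataFrame({"A": [1.0, 2.0, 3.0, 4.0], "B": [4.0, 3.0, 2.0, 1.0],
                    "database": ["database1", "database1", "database2", "database2"]})
calls = []


def record(*args, **kwargs):
    # Hypothetical stand-in for heat(): log the facet's database and the current Axes
    calls.append((kwargs["data"]["database"].iloc[0], plt.gca()))


fg = sns.FacetGrid(data=dat, col="database")
fg.map_dataframe(record)

for name, ax in calls:
    # Each call ran with its own facet's Axes active
    print(name, ax is fg.axes_dict[name])
```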

Then facet:

fg = sns.FacetGrid(data=dat, col='database',aspect=0.7,height=4)
fg.map_dataframe(heat)

[Image: resulting plot]
