lfo-po

Reputation: 53

Seaborn heatmaps in subplots - align x-axis

I am trying to plot a figure containing two subplots: a seaborn heatmap and a simple matplotlib line plot. However, when I share the x-axis between the two plots, they do not align, as can be seen in this figure:

[Figure: line plot (top) and heatmap (bottom) with misaligned x-axes]

It would seem that the problem is similar to this post, but ax[0].get_xticks() and ax[1].get_xticks() return the same positions, so I don't know what to change. Also, in my figure the deviation seems to be larger than a 0.5 shift.

What am I doing wrong?

The code I used to plot the figure is the following:

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

M_1=np.random.random((15,15))
M_2=np.random.random((15,15))

L_1=np.random.random(15)
L_2=np.random.random(15)

x=range(15)

cmap = sns.color_palette("hot", 100)
sns.set(style="white")

fig, ax = plt.subplots(2, 1, sharex='col', figsize=(10, 12))

# top: line plot, bottom: heatmap (seaborn adds its colorbar next to ax[1])
ax[0].plot(x,L_1,'-', marker='o',color='tab:orange')
sns.heatmap(M_1, cmap=cmap, vmax=np.max(M_1), center=np.max(M_1)/2., square=False, ax=ax[1])
plt.show()

Upvotes: 5

Views: 5334

Answers (1)

Diziet Asahi

Reputation: 40697

@Mr-T's comment is spot on. The easiest fix is to create the colorbar axes beforehand and pass them to heatmap() as cbar_ax, instead of letting heatmap() shrink your axes to make room for the colorbar.
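This also explains why ax[0].get_xticks() and ax[1].get_xticks() agree: the tick positions are shared in data coordinates, but the two axes end up with different widths on the figure. A minimal sketch of the effect, assuming that heatmap() draws its default colorbar through matplotlib's Figure.colorbar with ax= set to the heatmap axes. It uses plain pcolormesh, so it only needs numpy and matplotlib:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
M = rng.random((15, 15))
x = np.arange(15)

fig, ax = plt.subplots(2, 1, sharex='col')
ax[0].plot(x, rng.random(15), '-', marker='o', color='tab:orange')
mesh = ax[1].pcolormesh(M, cmap='hot')
# With ax= and no cax=, matplotlib steals room for the colorbar from ax[1]
fig.colorbar(mesh, ax=ax[1])

top = ax[0].get_position()     # figure-relative bounding box of the line plot
bottom = ax[1].get_position()  # figure-relative bounding box of the "heatmap"

print("same xticks:", np.array_equal(ax[0].get_xticks(), ax[1].get_xticks()))
print("top width:", round(top.width, 3), "bottom width:", round(bottom.width, 3))
```

The tick arrays match, yet the bottom axes is narrower, so the same data x-value lands at a different horizontal position in each subplot.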

There is an added complication: the heatmap's tick labels are not placed at [0, 1, ...] but at the center of each cell, [0.5, 1.5, ...]. So if you want your upper plot to align with the labels at the bottom (and with the center of each cell), you need to shift your plot by 0.5 units to the right:

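import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

M_1=np.random.random((15,15))
M_2=np.random.random((15,15))
L_1=np.random.random(15)
L_2=np.random.random(15)
x=np.arange(15)

cmap = sns.color_palette("hot", 100)
sns.set(style="white")

# 2x2 grid: a wide left column for the plots, a narrow right column for the colorbar
fig, ax = plt.subplots(2, 2, sharex='col', gridspec_kw={'width_ratios':[100,5]})
ax[0,1].remove()  # remove unused upper right axes
# shift by 0.5 so each point sits at the center of the matching heatmap cell
ax[0,0].plot(x+0.5,L_1,'-', marker='o',color='tab:orange')
# draw the colorbar into the pre-made axes instead of shrinking ax[1,0]
sns.heatmap(M_1, cmap=cmap, vmax=np.max(M_1), center=np.max(M_1)/2., square=False, ax=ax[1,0], cbar_ax=ax[1,1])
plt.show()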
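The x+0.5 offset is just cell-center arithmetic. A small sketch, assuming the heatmap cells are laid out like matplotlib's pcolormesh(M) with default coordinates, where column j occupies [j, j + 1] (consistent with the label positions described above):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

n_cols = 15
edges = np.arange(n_cols + 1)           # cell edges 0, 1, ..., 15
centers = (edges[:-1] + edges[1:]) / 2  # cell centers 0.5, 1.5, ..., 14.5
shifted_x = np.arange(n_cols) + 0.5     # the x+0.5 used for the line plot

fig, ax = plt.subplots()
ax.pcolormesh(np.random.default_rng(0).random((n_cols, n_cols)))
xlim = ax.get_xlim()  # default-coordinate mesh spans exactly 0..n_cols

print("centers match x + 0.5:", np.allclose(centers, shifted_x))
print("x-limits of the mesh:", xlim)
```

Without the shift, point j would sit on the left edge of cell j instead of its center.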

[Figure: resulting plot, with the line plot aligned to the heatmap cells]
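To check numerically that pre-creating the colorbar axes keeps both plots the same width, here is a matplotlib-only sketch of the same layout. It assumes that passing cbar_ax to heatmap() behaves like passing cax= to Figure.colorbar, i.e. the colorbar is drawn into the given axes and no other axes are resized:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
M_1 = rng.random((15, 15))
L_1 = rng.random(15)
x = np.arange(15)

fig, ax = plt.subplots(2, 2, sharex='col', gridspec_kw={'width_ratios': [100, 5]})
ax[0, 1].remove()  # unused upper right axes
ax[0, 0].plot(x + 0.5, L_1, '-', marker='o', color='tab:orange')
mesh = ax[1, 0].pcolormesh(M_1, cmap='hot')
# Drawing into an existing cax does not shrink any other axes
fig.colorbar(mesh, cax=ax[1, 1])

top = ax[0, 0].get_position()
bottom = ax[1, 0].get_position()
print("left edges:", top.x0, bottom.x0)
print("right edges:", top.x1, bottom.x1)
```

Both main axes keep the bounding box of their gridspec column, so their left and right edges coincide.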

Upvotes: 6
