Leonardo

Reputation: 4228

Centering a table with a heatmap

I'm trying to add a matplotlib table under a seaborn heatmap. I can plot both, but I can't get them aligned.

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Main data
df = pd.DataFrame({"A": [20, 10, 7, 39], 
                   "B": [1, 8, 12, 9], 
                   "C": [780, 800, 1200, 250]})

# Min and max values for the df columns
df_info =  pd.DataFrame({"A": [22, 35], 
                   "B": [5, 10], 
                   "C": [850, 900]})

df_norm = (df - df.min())/(df.max() - df.min())


# Plot the heatmap
vmin = df_norm.min().min()
vmax = df_norm.max().max()

fig, (ax1, ax2) = plt.subplots(nrows=2, sharex=True)
sns.heatmap(df_norm, ax=ax1, annot=df, cmap='RdBu_r', cbar=True)

When I add the table to ax2, it is drawn across the full width of the heatmap axes, color bar included. I've tried every combination of loc and bbox I could think of, but I haven't been able to center the table exactly and give it the same width as the heatmap above it (both the table as a whole and the individual cells).

table = df_info
cell_text = []
for row in range(len(table)):
    cell_text.append(table.iloc[row])

ax2.axis('off')
ax2.table(cellText=cell_text,
          rowLabels=table.index,
          colLabels=None,
          loc='center')

[image]

Also, I sometimes pass square=True to the heatmap to get square cells, and the result is this:

[image]

Question: How can I attach and center the table and its cells to the heatmap?

EDIT: tdy's answer does solve my simplified example, but I may have over-simplified it and left out an important detail.
In my real case, if I create the figure with:

fig, (ax1, ax2) = plt.subplots(nrows=2,
                               **{"figsize": (18, 18),
                                  "dpi": 200,
                                  "tight_layout": True})

and apply the approach from that answer, I get something like this:

[image]

The table ends up far below the heatmap and is wider than it.

If I set tight_layout=False instead, the table has the correct width but is still far below the heatmap:

fig, (ax1, ax2) = plt.subplots(nrows=2,
                               **{"figsize": (18, 18),
                                  "dpi": 200,
                                  "tight_layout": False})

[image]

I guess that in my case "figsize": (18, 18) and tight_layout have a big impact on the problem, but I'm not sure why or how to solve it.

Upvotes: 2

Views: 730

Answers (1)

tdy

Reputation: 41347

Short method

You can move/resize the table with Axes.set_position(). The left/bottom/width/height params can be tweaked as needed:

bbox1 = ax1.get_position()
bbox2 = ax2.get_position()

# modify as needed (the 0.8 factors are hand-tuned for this figure)
left = bbox1.x0
bottom = bbox1.y0 - (bbox2.height * 0.8)
width = bbox1.x0 + (bbox1.width * 0.8)
height = bbox2.height

ax2.set_position([left, bottom, width, height])

[image: heatmap and table aligned with set_position]
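As a rough, self-contained sketch of the same idea, the snippet below wraps the set_position() call in a helper. The name align_below and its gap parameter are my own, not part of matplotlib, and it copies the upper axes' left edge and width directly instead of scaling them by a hand-tuned factor. To keep it free of extra dependencies it uses a plain pcolormesh plus fig.colorbar in place of sns.heatmap, on the assumption that the color bar takes space from ax1 in the same way; check that against your own figure.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np


def align_below(ax_top, ax_bottom, gap=0.0):
    """Move ax_bottom directly under ax_top, copying its left edge and width.

    Hypothetical helper: positions are in figure coordinates (0 to 1).
    """
    top = ax_top.get_position()
    height = ax_bottom.get_position().height  # keep the table axes' own height
    new_pos = [top.x0, top.y0 - gap - height, top.width, height]
    ax_bottom.set_position(new_pos)
    return new_pos


fig, (ax1, ax2) = plt.subplots(nrows=2)

# Stand-in for sns.heatmap(..., cbar=True): a mesh plus a color bar on ax1
mesh = ax1.pcolormesh(np.arange(12).reshape(4, 3), cmap="RdBu_r")
fig.colorbar(mesh, ax=ax1)  # takes space from ax1, so ax1 becomes narrower

ax2.axis("off")
ax2.table(cellText=[["22", "5", "850"], ["35", "10", "900"]],
          rowLabels=["0", "1"], loc="center")

# Read ax1's position only after the color bar exists, then move ax2
align_below(ax1, ax2, gap=0.02)
```

The gap value is in figure fractions, so 0.02 leaves 2% of the figure height between the heatmap and the table; set it to 0 to attach them.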


Longer method

If the short method isn't working well, try setting the axes height_ratios via gridspec_kw in plt.subplots(). I also needed to set tight_layout=False.

Use height_ratios to set the ratio of ax1:ax2 (heatmap:table), and use the *_offset variables to adjust the table's size/position as needed. These values worked on my system; you may need to tweak them for yours:

### modify these params as needed ###

height_ratios = (20, 1) # heatmap:table ratio (20:1)
left_offset = 0         # table left position adjustment
bottom_offset = -0.025  # table bottom position adjustment
width_offset = -0.0005  # table width adjustment
height_offset = 0       # table height adjustment

#####################################

fig_kw = dict(figsize=(18, 18), dpi=200, tight_layout=False)
gridspec_kw = dict(height_ratios=height_ratios)
fig, (ax1, ax2) = plt.subplots(nrows=2, gridspec_kw=gridspec_kw, **fig_kw)

sns.heatmap(df_norm, ax=ax1, annot=df, cmap='RdBu_r', cbar=True)
ax2.table(cellText=[df_info.iloc[row] for row in range(len(df_info))],
          rowLabels=df_info.index,
          colLabels=None,
          loc='center')
ax2.axis('off')

bbox1 = ax1.get_position()
bbox2 = ax2.get_position()

left = bbox1.x0 + left_offset
bottom = bbox1.y0 - bbox2.height + bottom_offset
width = bbox1.width + width_offset
height = bbox2.height + height_offset

ax2.set_position([left, bottom, width, height])

[image: heatmap and table aligned with gridspec_kw and set_position]
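To get a feel for what height_ratios does before any of the offsets are applied, here is a small check that builds two empty axes and compares their heights. It assumes default matplotlib rcParams and no layout engine (tight_layout off); the variable names are only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this runs without a display
import matplotlib.pyplot as plt

# Same 20:1 heatmap:table split as above, but with empty axes
fig, (ax1, ax2) = plt.subplots(nrows=2,
                               gridspec_kw=dict(height_ratios=(20, 1)))

h1 = ax1.get_position().height  # heatmap axes height, in figure fractions
h2 = ax2.get_position().height  # table axes height, in figure fractions
ratio = h1 / h2

print(f"heatmap axes height: {h1:.3f}, table axes height: {h2:.3f}")
print(f"ratio: {ratio:.1f}")
```

A thin table axes like this leaves less empty space to bridge, which is why the bottom_offset above can stay small.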

Upvotes: 3
