user248237

How to make square subplots in matplotlib with heatmaps?

I'm trying to make a simple figure with a dendrogram in one subplot and a heatmap in another, while keeping both axes square. I tried the following:

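import matplotlib
from matplotlib import pyplot as plt
import numpy as np
from numpy import arange
from scipy.cluster.hierarchy import linkage
from scipy.cluster.hierarchy import dendrogram
from scipy.spatial.distance import pdist

fig = plt.figure(figsize=(7,7))
# top subplot: dendrogram of 5 random observations
plt.subplot(2, 1, 1)
cm = matplotlib.cm.Blues
X = np.random.random([5,5])
pmat = pdist(X, "euclidean")
linkmat = linkage(pmat)
dendrogram(linkmat)
# bottom subplot: 6x6 heatmap with one label per column
plt.subplot(2, 1, 2)
labels = ["a", "b", "c", "d", "e", "f"]
Y = np.random.random([6,6])
plt.xticks(arange(0.5, 6.5, 1))  # 6 ticks at the cell centers, matching the 6 labels
plt.gca().set_xticklabels(labels)
plt.pcolor(Y)
plt.colorbar()
plt.show()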

This yields the following:

[image]

But the problems are that the axes are not square, and the colorbar is treated as part of the second subplot. I'd like the colorbar to hang outside the plot instead, and I'd like the dendrogram box and the heatmap box to both be square and aligned with each other (i.e., the same size).

I tried passing aspect='equal' when calling subplot to get square axes, as the documentation suggests, but this ruined the plot, giving this:

[image]

If I use plt.axis('equal') after each subplot instead of aspect='equal', it strangely squares the heatmap but not its bounding box (see below), while destroying the dendrogram altogether and also messing up the alignment of the xtick labels, giving rise to this mess:

[image]

How can this be fixed? To summarize, I'm trying to plot something very simple: a square dendrogram in the top subplot and a square heatmap in the bottom subplot, with the colorbar on the right. Nothing fancy.

Finally, a more general question: is there a general rule or principle for forcing matplotlib to always make axes square? I cannot think of a single case where I don't want square axes, yet square axes are usually not the default. I'd like to force all plots to be square if possible.

Upvotes: 13

Views: 18733

Answers (3)

mobeets

Reputation: 460

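To add to the other answers: take the absolute values of the axis ranges before passing their ratio to .set_aspect. If an axis is inverted (for example by imshow, or after invert_yaxis()), get_xlim()/get_ylim() return the limits in reverse order, and without abs() the ratio would be negative: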

x0,x1 = ax1.get_xlim()
y0,y1 = ax1.get_ylim()
ax1.set_aspect(abs(x1-x0)/abs(y1-y0))
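A minimal, self-contained sketch of the same computation, assuming a hypothetical helper called `make_square` and an `imshow` plot, whose y-axis is inverted by default (so `get_ylim()` returns the larger value first). Without `abs()` the ratio here would be -2, which recent Matplotlib versions reject in `set_aspect`.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def make_square(ax):
    """Hypothetical helper: make the axes box square by setting the data
    aspect to x-range / y-range, using absolute ranges so that inverted
    axes still give a positive ratio."""
    x0, x1 = ax.get_xlim()
    y0, y1 = ax.get_ylim()
    ax.set_aspect(abs(x1 - x0) / abs(y1 - y0))


fig, ax = plt.subplots()
ax.imshow(np.random.random((4, 8)))  # 4 rows x 8 columns; y-axis runs top to bottom
make_square(ax)                      # x range 8, y range 4 -> aspect 2.0
```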

Upvotes: 1

pelson

Reputation: 21839

@HYRY's answer is very good and deserves all the credit. But to finish off the answer about lining the squared plots up nicely, you could trick matplotlib into thinking that both plots have colorbars, only making the first one invisible:

from scipy.cluster.hierarchy import linkage
from scipy.cluster.hierarchy import dendrogram
from scipy.spatial.distance import pdist
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
from numpy import arange

fig = plt.figure(figsize=(5,7))
ax1 = plt.subplot(2, 1, 1)
cm = matplotlib.cm.Blues
X = np.random.random([5,5])
pmat = pdist(X, "euclidean")
linkmat = linkage(pmat)
dendrogram(linkmat)
# set the aspect to x-range / y-range so the dendrogram box is square
x0,x1 = ax1.get_xlim()
y0,y1 = ax1.get_ylim()
ax1.set_aspect((x1-x0)/(y1-y0))

# aspect=1 makes the 6x6 heatmap box square
plt.subplot(2, 1, 2, aspect=1)
labels = ["a", "b", "c", "d", "e", "f"]
Y = np.random.random([6,6])
plt.xticks(arange(0.5, 6.5, 1))  # one tick per column, matching the 6 labels
plt.gca().set_xticklabels(labels)
plt.pcolor(Y)
plt.colorbar()

# add a colorbar to the first plot and immediately make it invisible:
# it still takes up the same space as the real one, so both boxes line up
cb = plt.colorbar(ax=ax1)
cb.ax.set_visible(False)

plt.show()

[image: code output]
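One way to check that the invisible colorbar really lines the two boxes up is to draw the figure and compare the axes positions. This is a sketch under stated assumptions, not the answer's exact code: it swaps the dendrogram for a plain line plot so only NumPy and Matplotlib are needed, uses the object-oriented API, and passes the `pcolor` mesh to both `colorbar` calls explicitly.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure(figsize=(5, 7))

# top: a plain line plot standing in for the dendrogram (assumption, avoids SciPy)
ax1 = fig.add_subplot(2, 1, 1)
ax1.plot([0, 10], [0, 2])
x0, x1 = ax1.get_xlim()
y0, y1 = ax1.get_ylim()
ax1.set_aspect(abs(x1 - x0) / abs(y1 - y0))  # square box

# bottom: 6x6 heatmap with a real colorbar
ax2 = fig.add_subplot(2, 1, 2, aspect=1)
mesh = ax2.pcolor(np.random.random((6, 6)))
fig.colorbar(mesh, ax=ax2)

# dummy colorbar for the top axes: reserves the same space, then is hidden
cb = fig.colorbar(mesh, ax=ax1)
cb.ax.set_visible(False)

fig.canvas.draw()
p1, p2 = ax1.get_position(), ax2.get_position()  # aspect-adjusted boxes
```

Without the two dummy-colorbar lines, the top box is centered in the full subplot width, so its left edge no longer lines up with the heatmap's.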

Upvotes: 10

HYRY

Reputation: 97281

aspect="equal" means that the same length in data space is drawn as the same length in screen space. In your top axes, the data ranges of the x-axis and the y-axis are not the same, so the box will not be square. To fix this, set the aspect to the ratio of the x-axis range to the y-axis range:

from scipy.cluster.hierarchy import linkage
from scipy.cluster.hierarchy import dendrogram
from scipy.spatial.distance import pdist
import matplotlib
from matplotlib import pyplot as plt
import numpy as np
from numpy import arange

fig = plt.figure(figsize=(5,7))
ax1 = plt.subplot(2, 1, 1)
cm = matplotlib.cm.Blues
X = np.random.random([5,5])
pmat = pdist(X, "euclidean")
linkmat = linkage(pmat)
dendrogram(linkmat)
# aspect = x-range / y-range makes the box square whatever the data ranges are
x0,x1 = ax1.get_xlim()
y0,y1 = ax1.get_ylim()
ax1.set_aspect((x1-x0)/(y1-y0))
# the heatmap spans 6 x 6 data units, so aspect=1 already gives a square box
plt.subplot(2, 1, 2, aspect=1)
labels = ["a", "b", "c", "d", "e", "f"]
Y = np.random.random([6,6])
plt.xticks(arange(0.5, 6.5, 1))  # one tick per column, matching the 6 labels
plt.gca().set_xticklabels(labels)
plt.pcolor(Y)
plt.colorbar()

Here is the output:

[image]
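To see why aspect="equal" alone is not enough, the sketch below measures the drawn box on axes whose x range (10 units) is ten times the y range (1 unit). `box_ratio` is a hypothetical helper, and the 6 x 6 inch figure is an arbitrary choice.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt


def box_ratio(ax):
    """Hypothetical helper: physical width / height of the drawn axes box."""
    fig = ax.figure
    fig.canvas.draw()                 # make sure the aspect has been applied
    pos = ax.get_position()           # box in figure fractions
    fig_w, fig_h = fig.get_size_inches()
    return (pos.width * fig_w) / (pos.height * fig_h)


fig, ax = plt.subplots(figsize=(6, 6))
ax.set_xlim(0, 10)  # x data range: 10 units
ax.set_ylim(0, 1)   # y data range: 1 unit

# 'equal': one data unit is the same length on both axes -> a wide, flat box
ax.set_aspect("equal")
ratio_equal = box_ratio(ax)

# aspect = x-range / y-range -> a square box
x0, x1 = ax.get_xlim()
y0, y1 = ax.get_ylim()
ax.set_aspect((x1 - x0) / (y1 - y0))
ratio_square = box_ratio(ax)
```

With "equal" the box comes out about ten times wider than tall, mirroring the 10:1 ratio of the data ranges; using the range ratio as the aspect compensates for exactly that.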

To position the colorbar, we need to write a ColorBarLocator class. The pad and width arguments are in pixels:

  • pad: the space between the axes and its colorbar
  • width: the width of the colorbar

Replace plt.colorbar() with the following code:

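from matplotlib.transforms import Bbox

class ColorBarLocator(object):
    """Place a colorbar axes just to the right of the parent axes `pax`.
    pad and width are given in pixels."""
    def __init__(self, pax, pad=5, width=10):
        self.pax = pax
        self.pad = pad
        self.width = width

    def __call__(self, ax, renderer):
        # current (aspect-adjusted) position of the parent axes, in figure fractions
        x, y, w, h = self.pax.get_position().bounds
        fig = self.pax.get_figure()
        # convert the pixel distances to figure fractions
        inv_trans = fig.transFigure.inverted()
        pad, _ = inv_trans.transform([self.pad, 0])
        width, _ = inv_trans.transform([self.width, 0])
        # same height as the parent, starting `pad` pixels to its right
        return Bbox.from_bounds(x+w+pad, y, width, h)

ax2 = plt.gca()  # the heatmap axes (still current, since plt.colorbar() was not called)
cax = fig.add_axes([0,0,0,0], axes_locator=ColorBarLocator(ax2))
plt.colorbar(cax = cax)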

[image]
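A self-contained check of the locator, assuming a 5 x 7 inch figure at 100 dpi (so 500 x 700 pixels) and omitting the dendrogram: after drawing, the colorbar should start 5 px (0.01 of the figure width) to the right of the heatmap, be 10 px (0.02) wide, and have the same height. The class repeats the answer's locator, returning a Bbox.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.transforms import Bbox


class ColorBarLocator:
    """The answer's locator: a strip `width` px wide, `pad` px right of `pax`."""

    def __init__(self, pax, pad=5, width=10):
        self.pax, self.pad, self.width = pax, pad, width

    def __call__(self, ax, renderer):
        x, y, w, h = self.pax.get_position().bounds
        to_fig = self.pax.figure.transFigure.inverted()  # pixels -> figure fraction
        pad, _ = to_fig.transform([self.pad, 0])
        width, _ = to_fig.transform([self.width, 0])
        return Bbox.from_bounds(x + w + pad, y, width, h)


fig = plt.figure(figsize=(5, 7), dpi=100)  # 500 x 700 pixels
ax2 = fig.add_subplot(2, 1, 2, aspect=1)
mesh = ax2.pcolor(np.random.random((6, 6)))
cax = fig.add_axes([0, 0, 0, 0], axes_locator=ColorBarLocator(ax2))
fig.colorbar(mesh, cax=cax)

fig.canvas.draw()  # axes locators are evaluated at draw time
heat, bar = ax2.get_position(), cax.get_position()
```

Because the locator reads the heatmap's position each time the figure is drawn, the colorbar should follow the heatmap when the figure is resized, unlike a colorbar placed at fixed fig.add_axes coordinates.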

Upvotes: 15
