gmavrom

Reputation: 440

Maintain consistent subplot size for different layouts

I am facing what I thought would be a simple problem, but I am struggling to find a simple and scalable solution. Basically, I would like to make several figures in Matplotlib, each with a different number of subplots and a different layout.

The specific requirement I have for these figures is that all subplots, across all figures, must have exactly the same size.

The simplest solution I have tried is to scale the figsize according to the number of subplots:

import numpy as np
import matplotlib.pyplot as plt


x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

fig, ax = plt.subplots(2, 2, figsize=(10,6))
for i in ax.flatten():
    i.plot(x, y)
plt.savefig('f1.pdf')

fig, ax = plt.subplots(3, 2, figsize=(10,9))
for i in ax.flatten():
    i.plot(x, y)
plt.savefig('f2.pdf')

fig, ax = plt.subplots(2, 3, figsize=(15,6))
for i in ax.flatten():
    i.plot(x, y)
plt.savefig('f3.pdf')

So in the code above, the figsize is 10 in x 6 in for the 2x2 layout and 10 in x 9 in for the 3x2 layout, i.e. every extra row adds 3 in of height and every extra column adds 5 in of width.

This makes the subplots in the different figures quite similar in size, but not exactly the same (I checked this by importing each figure into Adobe Illustrator and measuring the axes).
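A likely reason for the mismatch: in a plain plt.subplots grid the outer margins (left, right, bottom, top) are fractions of the figure size, while the gaps between axes (wspace, hspace) are fractions of the average axes size, so growing the figsize linearly with the row/column count does not grow the axes linearly. The sketch below reproduces the cell-size arithmetic that GridSpec uses, assuming Matplotlib's default rcParams["figure.subplot.*"] values and no tight or constrained layout; the helper name axes_size_inches is hypothetical.

```python
# Matplotlib's default subplot parameters (assumed unchanged from rcParams).
# Margins are fractions of the figure size.
LEFT, RIGHT, BOTTOM, TOP = 0.125, 0.9, 0.11, 0.88
# Gaps are fractions of the average axes width / height.
WSPACE, HSPACE = 0.2, 0.2


def axes_size_inches(nrows, ncols, figsize):
    """Hypothetical helper: (width, height) in inches of each axes in an
    nrows x ncols plt.subplots grid, using GridSpec's cell-size formula."""
    fig_w, fig_h = figsize
    cell_w = (RIGHT - LEFT) / (ncols + WSPACE * (ncols - 1))
    cell_h = (TOP - BOTTOM) / (nrows + HSPACE * (nrows - 1))
    return cell_w * fig_w, cell_h * fig_h


# The three layouts from the question: (nrows, ncols) and figsize.
for layout, figsize in [((2, 2), (10, 6)), ((3, 2), (10, 9)), ((2, 3), (15, 6))]:
    w, h = axes_size_inches(*layout, figsize)
    print(layout, f"{w:.3f} x {h:.3f} in")
```

Under these assumptions the axes come out at about 3.523 x 2.100 in (f1), 3.523 x 2.038 in (f2) and 3.419 x 2.100 in (f3): close, but not identical, which is consistent with what the Illustrator measurements show.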

Is there an easy, scalable way to guarantee the same subplot size in every figure, for an arbitrary number of subplots and any layout? I imagine something where I specify the subplot size instead of the figsize, but I have not figured out how yet...

Any help will be appreciated!

Upvotes: 1

Views: 836

Answers (1)

Stef

Reputation: 30579

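You may want to use a Divider from mpl_toolkits.axes_grid1. The following example makes every axes 3.5" wide (Size.Fixed(3.5)) and 2.0" high (Size.Fixed(2)), and splits the remaining figure space evenly among the pads (Size.Scaled(1) for every pad). The figsize then only has to be large enough to hold the fixed sizes; the axes themselves are the same size in every figure.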
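The padding arithmetic can be checked by hand: Size.Fixed entries take their absolute size, and whatever is left of the figure width (or height) is shared among the Size.Scaled entries in proportion to their factors, here equally. A minimal sketch of that calculation for the three figures below; the helper scaled_pad_inches is hypothetical, and its ValueError is an addition of this sketch (Divider itself does not raise; the pads simply become negative and the axes overlap).

```python
def scaled_pad_inches(fig_len, fixed_lens, n_scaled):
    """Hypothetical helper: inches given to each Size.Scaled(1) pad when the
    Size.Fixed entries (fixed_lens) are laid out along a figure of fig_len."""
    leftover = fig_len - sum(fixed_lens)
    if leftover < 0:
        # Divider would not complain; the axes would just overlap.
        raise ValueError("figure too small for the fixed axes sizes")
    return leftover / n_scaled


# figsize, nrows, ncols for the three figures in the answer
layouts = {
    "f1 (2x2)": ((10, 6), 2, 2),
    "f2 (3x2)": ((10, 9), 3, 2),
    "f3 (2x3)": ((15, 6), 2, 3),
}
for name, ((fig_w, fig_h), nrows, ncols) in layouts.items():
    pad_h = scaled_pad_inches(fig_w, [3.5] * ncols, ncols + 1)  # h list: ncols+1 pads
    pad_v = scaled_pad_inches(fig_h, [2.0] * nrows, nrows + 1)  # v list: nrows+1 pads
    print(f"{name}: horizontal pad {pad_h:.3f} in, vertical pad {pad_v:.3f} in")
```

Only the pads change from figure to figure; the 3.5" x 2.0" axes never do.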

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import Divider, Size    

x = np.linspace(0, 2 * np.pi, 400)
y = np.sin(x ** 2)

sc = Size.Scaled(1)   # pad: equal share of the leftover figure space
fh = Size.Fixed(3.5)  # axes width in inches
fv = Size.Fixed(2)    # axes height in inches

fig, ax = plt.subplots(2, 2, figsize=(10,6))
h = [sc, fh] * 2 + [sc]
v = [sc, fv] * 2 + [sc]
divider = Divider(fig, (0.0, 0.0, 1., 1.), h, v)
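# nx counts columns from the left, ny counts rows from the bottom,
# so the row index is flipped to keep ax[0, 0] in the top-left corner
for i in range(2):
    for j in range(2):
        ax[i,j].set_axes_locator(divider.new_locator(nx=2*j+1, ny=2*(1-i)+1))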
for i in ax.flatten():
    i.plot(x, y)
plt.savefig('f1.pdf')

fig, ax = plt.subplots(3, 2, figsize=(10,9))
h = [sc, fh] * 2 + [sc]
v = [sc, fv] * 3 + [sc]
divider = Divider(fig, (0.0, 0.0, 1., 1.), h, v)
for i in range(3):
    for j in range(2):
        # ny counts rows from the bottom, so flip the row index
        ax[i,j].set_axes_locator(divider.new_locator(nx=2*j+1, ny=2*(2-i)+1))
for i in ax.flatten():
    i.plot(x, y)
plt.savefig('f2.pdf')

fig, ax = plt.subplots(2, 3, figsize=(15,6))
h = [sc, fh] * 3 + [sc]
v = [sc, fv] * 2 + [sc]
divider = Divider(fig, (0.0, 0.0, 1., 1.), h, v)
for i in range(2):
    for j in range(3):
        # ny counts rows from the bottom, so flip the row index
        ax[i,j].set_axes_locator(divider.new_locator(nx=2*j+1, ny=2*(1-i)+1))
for i in ax.flatten():
    i.plot(x, y)
plt.savefig('f3.pdf')
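To make this scale to any layout, the per-figure boilerplate can be folded into one function that also derives the figsize from the layout, so the figure is always exactly large enough. This is a sketch under assumptions: the function name fixed_size_subplots and its parameters are hypothetical, and it uses Size.Fixed pads instead of Size.Scaled ones, so no leftover space has to be shared out.

```python
import matplotlib

matplotlib.use("Agg")  # non-interactive backend so the sketch also runs headless
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import Divider, Size


def fixed_size_subplots(nrows, ncols, ax_w=3.5, ax_h=2.0, pad=0.8):
    """Hypothetical helper: nrows x ncols axes of exactly ax_w x ax_h inches,
    separated from each other and from the figure edge by pad inches."""
    # Alternate pad / axes / pad ... along each direction.
    h = [Size.Fixed(pad)] + [Size.Fixed(ax_w), Size.Fixed(pad)] * ncols
    v = [Size.Fixed(pad)] + [Size.Fixed(ax_h), Size.Fixed(pad)] * nrows
    # The figure is exactly the sum of the fixed sizes.
    figsize = (pad + ncols * (ax_w + pad), pad + nrows * (ax_h + pad))
    fig, axs = plt.subplots(nrows, ncols, figsize=figsize, squeeze=False)
    divider = Divider(fig, (0, 0, 1, 1), h, v)
    for i in range(nrows):
        for j in range(ncols):
            # Odd indices are the axes slots; ny counts from the bottom.
            axs[i, j].set_axes_locator(
                divider.new_locator(nx=2 * j + 1, ny=2 * (nrows - 1 - i) + 1))
    return fig, axs
```

Usage would follow the pattern above, e.g. `fig, ax = fixed_size_subplots(3, 2)`, then `for a in ax.flat: a.plot(x, y)` and `fig.savefig('f2.pdf')`. The locator is applied when the figure is drawn or saved, so `ax.get_position()` only reports the final size after that.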

Upvotes: 2
