aaronsnoswell

Matplotlib: Shared axis for imshow images

I'm trying to plot multiple images with Matplotlib's imshow() method, and have them share a single y axis. Although the images have the same number of y pixels, the images don't end up the same height.

Demonstration code:


import numpy as np
import matplotlib.pyplot as plt

from scipy.stats import poisson


def ibp_oneparam(alpha, N):
    """One-parameter IBP"""

    # First customer
    Z = np.array([np.ones(poisson(alpha).rvs(1))], dtype=int)

    # ith customer
    for i in range(2, N+1):

        # Customer walks along previously sampled dishes
        z_i = []
        for previously_sampled_dish in Z.T:
            m_k = np.sum(previously_sampled_dish)
            # Customer i samples dish k with probability m_k / i
            if np.random.rand() < m_k / i:
                # Customer decides to sample this dish
                z_i.append(1.0)
            else:
                # Customer decides to skip this dish
                z_i.append(0.0)

        # Customer decides to try some new dishes
        z_i.extend(np.ones(poisson(alpha / i).rvs(1)))
        z_i = np.array(z_i)

        # Add this customer to Z
        Z_new = np.zeros((
            Z.shape[0] + 1,
            max(Z.shape[1], len(z_i))
        ))
        Z_new[0:Z.shape[0], 0:Z.shape[1]] = Z
        Z = Z_new
        Z[i-1, :] = z_i

    return Z


np.random.seed(3)

N = 10
alpha = 2.0

#plt.figure(dpi=100)
fig, (ax1, ax2, ax3) = plt.subplots(
    1,
    3,
    dpi=100,
    sharey=True
)

Z = ibp_oneparam(alpha, N)
plt.sca(ax1)
plt.imshow(
    Z,
    extent=(0.5, Z.shape[1] + 0.5, len(Z) + 0.5, 0.5),
    cmap='Greys_r'
)
plt.ylabel("Customers")
plt.xlabel("Dishes")
plt.xticks(range(1, Z.shape[1] + 1))
plt.yticks(range(1, Z.shape[0] + 1))

Z = ibp_oneparam(alpha, N)
plt.sca(ax2)
plt.imshow(
    Z,
    extent=(0.5, Z.shape[1] + 0.5, len(Z) + 0.5, 0.5),
    cmap='Greys_r'
)
plt.xlabel("Dishes")
plt.xticks(range(1, Z.shape[1] + 1))

Z = ibp_oneparam(alpha, N)
plt.sca(ax3)
plt.imshow(
    Z,
    extent=(0.5, Z.shape[1] + 0.5, len(Z) + 0.5, 0.5),
    cmap='Greys_r'
)
plt.xlabel("Dishes")
plt.xticks(range(1, Z.shape[1] + 1))

plt.show()

Output:

[Image: three subplots, each showing a binary image]

I expect these images to each have the same height, and have varying widths. How can I achieve this?

Aside: The code above is demonstrating the Indian Buffet Process. For the purposes of this post, consider the three images to be random binary matrices with the same number of rows, but variable numbers of columns.
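As the aside says, the plotting question only needs binary matrices with a fixed number of rows. For anyone who wants to run the example without SciPy, here is a numpy-only sketch of the one-parameter IBP. The name ibp_oneparam_np and its rng parameter are hypothetical. It swaps scipy.stats.poisson and np.random.rand for numpy.random.Generator.poisson and Generator.random, uses the standard rule that customer i takes an existing dish k with probability m_k / i, and pads the matrix once at the end instead of reallocating it for every customer. It will not reproduce the exact matrices that np.random.seed(3) gives.

```python
import numpy as np


def ibp_oneparam_np(alpha, N, rng=None):
    """Sample an N x K binary matrix from the one-parameter Indian Buffet Process.

    Rows are customers, columns are dishes; K varies from draw to draw.
    """
    rng = np.random.default_rng() if rng is None else rng
    counts = np.zeros(0, dtype=int)  # m_k: customers who have taken dish k so far
    rows = []
    for i in range(1, N + 1):
        # Existing dishes: customer i takes dish k with probability m_k / i
        old = (rng.random(counts.size) < counts / i).astype(int)
        # New dishes: Poisson(alpha / i) of them, all taken by customer i
        n_new = rng.poisson(alpha / i)
        rows.append(np.concatenate([old, np.ones(n_new, dtype=int)]))
        counts = np.concatenate([counts + old, np.ones(n_new, dtype=int)])
    # Earlier customers saw fewer dishes, so pad their rows with zeros
    Z = np.zeros((N, counts.size), dtype=int)
    for r, row in enumerate(rows):
        Z[r, :row.size] = row
    return Z
```

Calling ibp_oneparam_np(2.0, 10) in place of ibp_oneparam(alpha, N) should give matrices of the same kind (10 rows, a varying number of columns) for the plotting code.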

Thank you,

Answers (1)

kpie

I got a decent result with GridSpec's width_ratios.

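# In the question's code, replace this call:
#
# fig, (ax1, ax2, ax3) = plt.subplots(
#     1,
#     3,
#     dpi=100,
#     sharey=True,
#     constrained_layout=True
# )
#
# with the following:

import matplotlib.gridspec as gridspec

fig = plt.figure(constrained_layout=True)
# Width ratios proportional to the number of columns (dishes) in each matrix:
# 7, 4 and 6 in this example, divided by 4 so that the middle column is 1
gs = gridspec.GridSpec(ncols=3, nrows=1, figure=fig, width_ratios=[7./4., 1, 6./4.])
ax1 = fig.add_subplot(gs[0, 0])
# sharey=ax1 keeps the shared y axis that plt.subplots(sharey=True) provided
ax2 = fig.add_subplot(gs[0, 1], sharey=ax1)
ax3 = fig.add_subplot(gs[0, 2], sharey=ax1)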

[Image: resulting figure]
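The width ratios in the snippet above are fixed numbers, so they only suit these three particular matrices. Below is a sketch that computes them from the matrices themselves; the helper name plot_equal_height is hypothetical, and it assumes every matrix has the same number of rows and at least one column. It uses each matrix's column count directly as its width ratio and passes sharey to add_subplot to keep the shared y axis from the question.

```python
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec


def plot_equal_height(matrices, cmap="Greys_r"):
    """Show binary matrices side by side so the images share one height.

    Assumes every matrix has the same number of rows and at least one column.
    """
    # One grid column per matrix, with width proportional to its column count
    width_ratios = [Z.shape[1] for Z in matrices]
    fig = plt.figure(constrained_layout=True)
    gs = gridspec.GridSpec(ncols=len(matrices), nrows=1, figure=fig,
                           width_ratios=width_ratios)
    axes = []
    for j, Z in enumerate(matrices):
        # Share the y axis with the first subplot, like sharey=True in the question
        ax = fig.add_subplot(gs[0, j], sharey=axes[0] if axes else None)
        ax.imshow(Z, extent=(0.5, Z.shape[1] + 0.5, Z.shape[0] + 0.5, 0.5),
                  cmap=cmap)
        ax.set_xlabel("Dishes")
        ax.set_xticks(range(1, Z.shape[1] + 1))
        if j > 0:
            # Hide the repeated y tick labels on the inner subplots
            ax.tick_params(labelleft=False)
        axes.append(ax)
    axes[0].set_ylabel("Customers")
    axes[0].set_yticks(range(1, matrices[0].shape[0] + 1))
    return fig, axes
```

With the question's sampler, something like fig, axes = plot_equal_height([ibp_oneparam(alpha, N) for _ in range(3)]) followed by plt.show() should draw the three images at the same height without hand-tuned ratios.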

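It's somewhat counterintuitive that you need width ratios to adjust the heights. The reason is that imshow uses aspect='equal' by default, so each axes is shrunk to keep the image's pixels square. With equal-width subplots, the image with the most columns gets the smallest pixels and therefore ends up the shortest; making each subplot's width proportional to its number of columns gives every image the same pixel size, and so the same height. In the context of a grid with multiple rows it also makes sense that you can only scale columns independently by width, and rows independently by height. See https://matplotlib.org/tutorials/intermediate/gridspec.html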
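A back-of-the-envelope check of that reasoning in plain Python, under a simplified model: each image fills the width of its axes, so with square pixels its height is width * rows / columns. The helper name rendered_heights is hypothetical, and the shapes (10 rows; 7, 4 and 6 columns) are the ones implied by width_ratios=[7./4., 1, 6./4.] in the snippet.

```python
def rendered_heights(n_rows, n_cols, axes_widths):
    """Image heights when pixels are square and each image fills its axes width.

    Simplified model: ignores tick labels, margins, and the case where an
    image is limited by the available height instead of the width.
    """
    return [w * n_rows / c for w, c in zip(axes_widths, n_cols)]


n_rows = 10          # N customers in the question
n_cols = [7, 4, 6]   # column counts matching width_ratios=[7/4, 1, 6/4]

# Equal widths, as plt.subplots(1, 3) gives: the heights differ
print(rendered_heights(n_rows, n_cols, [1.0, 1.0, 1.0]))

# Widths proportional to the column counts: the heights match
print(rendered_heights(n_rows, n_cols, [7 / 4, 1, 6 / 4]))  # → [2.5, 2.5, 2.5]
```

Dividing every ratio by the same constant (4 here) changes nothing, which is why the column counts themselves also work as width_ratios.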
