r_31415

Reputation: 8972

Embedding several inset axes in another axis using matplotlib

Is it possible to embed a variable number of plots in a matplotlib axis? For example, the inset_axes method is used to place inset axes inside parent axes:

[Image: an inset axes placed inside a parent axes]
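For reference, here is a minimal sketch of the inset_axes approach applied to a single parent axes. It assumes matplotlib 3.0 or later, where Axes.inset_axes was added. The bounds [x0, y0, width, height] are given in the parent's axes coordinates (0 to 1) by default. The helper name stack_insets and the margin value are illustrative only.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def stack_insets(parent, n, margin=0.05):
    """Hypothetical helper: stack n inset axes vertically inside *parent*.

    Bounds are expressed in the parent's axes coordinates (0..1), with
    *margin* left free around and between the insets.
    """
    height = (1 - (n + 1) * margin) / n
    insets = []
    for k in range(n):
        # Place inset k from the top down: inset 0 ends at 1 - margin,
        # the last inset starts at margin
        y0 = 1 - (k + 1) * (height + margin)
        insets.append(parent.inset_axes([margin, y0, 1 - 2 * margin, height]))
    return insets


fig, parent = plt.subplots(figsize=(4, 6))
parent.set(xticks=[], yticks=[])  # the parent only serves as a frame
insets = stack_insets(parent, 5)
for inset in insets:
    inset.plot(np.random.random(40))
```

Because the insets are positioned relative to the parent, they move and resize with it, so the parent could just as well be the last axes of a subplot row.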

However, I have several rows of plots and I want to include some inset axes inside the last axis object of each row.

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(2, 4, figsize=(15, 15))
for i in range(2):
    ax[i][0].plot(np.random.random(40))
    ax[i][1].plot(np.random.random(40))
    ax[i][2].plot(np.random.random(40))

    # number of inset axes
    number_inset = 5
    for j in range(number_inset):
        # all components currently end up overlaid in the last column
        ax[i][3].plot(np.random.random(40))

[Image: output of the code above, with all five random series overlaid in the last column]

Here, instead of the 5 lines drawn in the last column, I want several inset axes, each containing one plot. Something like this:

[Image: sketch of the desired layout, with several small stacked plots in the last column of each row]

The reason for this is that every row refers to a different item to be plotted, and the last column is supposed to contain the components of that item. Is there a way to do this in matplotlib, or maybe an alternative way to visualize this?

Thanks

Upvotes: 3

Views: 2838

Answers (2)

r_31415

Reputation: 8972

This is what I did to obtain the same result without fixing the number of inset plots in advance: numinsets is set inside the row loop, so each row can have its own count.

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

fig = plt.figure(figsize=(12, 6))

nrows = 2
ncols = 4

# GridSpec can be indexed directly as outer_grid[row, col],
# so there is no need to reshape it into a NumPy array
outer_grid = gridspec.GridSpec(nrows, ncols)

for i in range(nrows):
    # The first ncols - 1 cells of each row hold a single "main" axes
    for j in range(ncols - 1):
        ax = fig.add_subplot(outer_grid[i, j])
        ax.plot(np.random.normal(0, 1, 50).cumsum())

    # this value can be set based on some other calculation depending
    # on each row
    numinsets = 3
    # Split the last cell of the row into numinsets stacked sub-rows
    inner_grid = gridspec.GridSpecFromSubplotSpec(
        numinsets, 1, subplot_spec=outer_grid[i, -1])

    # Adding subplots to the last inner grid
    for k in range(numinsets):
        ax = fig.add_subplot(inner_grid[k])
        ax.plot(np.random.normal(0, 1, 50).cumsum())

# Removing tick labels
for ax in fig.axes:
    ax.set(xticklabels=[], yticklabels=[])

fig.tight_layout()
plt.show()

[Image: resulting figure, with three stacked plots in the last column of each row]
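Because the inset count is chosen inside the row loop, it can differ from row to row. Here is a minimal sketch of that variation. It assumes the counts come from a hypothetical list insets_per_row, which in practice would be derived from the item plotted in each row.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import numpy as np

# Hypothetical per-row inset counts; in practice derived from each row's item
insets_per_row = [3, 5]
nrows, ncols = len(insets_per_row), 4

fig = plt.figure(figsize=(12, 6))
outer_grid = gridspec.GridSpec(nrows, ncols)

inset_axes_by_row = []
for i, numinsets in enumerate(insets_per_row):
    # Ordinary axes in every column except the last
    for j in range(ncols - 1):
        ax = fig.add_subplot(outer_grid[i, j])
        ax.plot(np.random.normal(0, 1, 50).cumsum())

    # Split the last cell of this row into numinsets stacked sub-rows
    inner_grid = gridspec.GridSpecFromSubplotSpec(
        numinsets, 1, subplot_spec=outer_grid[i, -1])
    row_insets = [fig.add_subplot(inner_grid[k]) for k in range(numinsets)]
    for ax in row_insets:
        ax.plot(np.random.normal(0, 1, 50).cumsum())
    inset_axes_by_row.append(row_insets)
```

Keeping the per-row lists in inset_axes_by_row makes it easy to address the components of a given row later, for example to add titles.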

Upvotes: 1

Joe Kington

Reputation: 284582

As @hitzg mentioned, the most common way to accomplish something like this is to use GridSpec. A GridSpec defines a virtual grid that you can slice to produce subplots, each spanning one or more grid cells. It's an easy way to align fairly complex layouts that follow a regular grid.
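To see what slicing does, here is a small sketch. Indexing a GridSpec returns a SubplotSpec, and a slice spans several cells. It assumes matplotlib 3.2 or later, where SubplotSpec.rowspan and SubplotSpec.colspan are available.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

fig = plt.figure()
gs = GridSpec(6, 4, figure=fig)  # 6 rows x 4 columns of cells

big = gs[:3, 0]     # spans rows 0-2 of column 0
small = gs[4, -1]   # a single cell in the last column

print(big.rowspan, big.colspan)      # range(0, 3) range(0, 1)
print(small.rowspan, small.colspan)  # range(4, 5) range(3, 4)

ax = fig.add_subplot(big)  # one axes occupying the three stacked cells
```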

However, it may not be immediately obvious how to use it in this case. You'll need to create a GridSpec with numrows * numinsets rows and numcols columns, and then create each "main" axes by slicing it in blocks of numinsets rows.

In the example below (2 rows, 4 columns, 3 insets), we'd slice with gs[:3, 0] to get the upper-left "main" axes, gs[3:, 0] to get the lower-left "main" axes, gs[:3, 1] to get the next upper axes, etc. Each inset occupies a single cell in the last column, gs[m, -1], for m from 0 to 5.
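The index arithmetic behind those slices can be checked without drawing anything. The following plain-Python sketch reproduces the row ranges described above. The helper names main_rows and inset_row are illustrative only.

```python
def main_rows(i, numinsets):
    """Row slice of the GridSpec covered by the main axes of row i."""
    return slice(i * numinsets, (i + 1) * numinsets)


def inset_row(i, k, numinsets):
    """GridSpec row of inset k in row i (always in the last column)."""
    return i * numinsets + k


numrows, numinsets = 2, 3
for i in range(numrows):
    rows = main_rows(i, numinsets)
    insets = [inset_row(i, k, numinsets) for k in range(numinsets)]
    print(f"row {i}: main gs[{rows.start}:{rows.stop}, j], insets {insets}")
# row 0: main gs[0:3, j], insets [0, 1, 2]
# row 1: main gs[3:6, j], insets [3, 4, 5]
```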

As a complete example:

import numpy as np
import matplotlib.pyplot as plt

def build_axes_with_insets(numrows, numcols, numinsets, **kwargs):
    """
    Makes a *numrows* x *numcols* grid of subplots with *numinsets* subplots
    embedded as "sub-rows" in the last column of each row.

    Returns a figure object and a *numrows* x *numcols* object ndarray where
    all but the last column consists of axes objects, and the last column is a
    *numinsets* length object ndarray of axes objects.
    """
    fig = plt.figure(**kwargs)
    gs = plt.GridSpec(numrows*numinsets, numcols)

    axes = np.empty([numrows, numcols], dtype=object)
    for i in range(numrows):
        # Add "main" axes...
        for j in range(numcols - 1):
            axes[i, j] = fig.add_subplot(gs[i*numinsets:(i+1)*numinsets, j])

        # Add inset axes: the last column holds a small object array of axes
        axes[i, -1] = np.empty(numinsets, dtype=object)
        for k in range(numinsets):
            m = k + i * numinsets
            axes[i, -1][k] = fig.add_subplot(gs[m, -1])

    return fig, axes

def plot(axes):
    """Recursive plotting function just to put something on each axes."""
    for ax in axes.flat:
        data = np.random.normal(0, 1, 100).cumsum()
        try:
            ax.plot(data)
            ax.set(xticklabels=[], yticklabels=[])
        except AttributeError:
            plot(ax)

fig, axes = build_axes_with_insets(2, 4, 3, figsize=(12, 6))
plot(axes)
fig.tight_layout()
plt.show()

[Image: resulting 2 x 4 figure, with three stacked inset plots in the last column of each row]

Upvotes: 3
