goodside

Reputation: 4629

Create matplotlib subplots without manually counting number of subplots?

When doing ad-hoc analysis in Jupyter Notebook, I often want to view sequences of transformations to some Pandas DataFrame as vertically stacked subplots. My usual quick-and-dirty method is to not use subplots at all, but create a new figure for each plot:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

df = pd.DataFrame({"a": range(100)})  # Some arbitrary DataFrame
df.plot(title="0 to 100")
plt.show()

df = df * -1  # Some transformation
df.plot(title="0 to -100")
plt.show()

df = df * 2  # Some other transformation
df.plot(title="0 to -200")
plt.show()

This method has limitations. The x-axis ticks are unaligned even when identically indexed (because the x-axis width depends on y-axis labels) and the Jupyter cell output contains several separate inline images, not a single one that I can save or copy-and-paste.

As far as I know, the proper solution is to use plt.subplots():

fig, axes = plt.subplots(3, figsize=(20, 9))

df = pd.DataFrame({"a": range(100)}) # Arbitrary DataFrame
df.plot(ax=axes[0], title="0 to 100")

df = df * -1 # Some transformation
df.plot(ax=axes[1], title="0 to -100")

df = df * 2 # Some other transformation
df.plot(ax=axes[2], title="0 to -200")

plt.tight_layout()
plt.show()

This yields exactly the output I'd like. However, it also introduces an annoyance that makes me use the first method by default: I have to manually count the number of subplots I've created and update this count in several different places as the code changes.

In the multi-figure case, adding a fourth plot is as simple as calling df.plot() and plt.show() a fourth time. With subplots, the equivalent change requires updating the subplot count, plus arithmetic to resize the output figure, replacing plt.subplots(3, figsize=(20, 9)) with plt.subplots(4, figsize=(20, 12)). Each newly added subplot needs to know how many other subplots already exist (ax=axes[0], ax=axes[1], ax=axes[2], etc.), so any additions or removals require cascading changes to the plots below.

This seems like it should be trivial to automate — it's just counting and multiplication — but I'm finding it impossible to implement with the matplotlib/pyplot API. The closest I can get is the following partial solution, which is terse enough but still requires explicit counting:

n_subplots = 3  # Must still be updated manually as code changes

fig, axes = plt.subplots(n_subplots, figsize=(20, 3 * n_subplots))
i = 0  # Counts how many subplots have been added so far 

df = pd.DataFrame({"a": range(100)}) # Arbitrary DataFrame
df.plot(ax=axes[i], title="0 to 100")
i += 1

df = df * -1 # Arbitrary transformation
df.plot(ax=axes[i], title="0 to -100")
i += 1

df = df * 2 # Arbitrary transformation
df.plot(ax=axes[i], title="0 to -200")
i += 1

plt.tight_layout()
plt.show()

The root problem is that any time df.plot() is called, there must exist an axes list of known size. I considered delaying the execution of df.plot() somehow, e.g. by appending to a list of lambda functions that can be counted before they're called in sequence, but this seems like an extreme amount of ceremony just to avoid updating an integer by hand.
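
For reference, the deferred approach described above could look roughly like the sketch below. It is only an illustration of the idea, not a recommended solution: the names pending and draw are made up for this example, and the d=df default argument is one way (an assumption about how you'd want it to behave) to make each lambda keep the DataFrame as it was when the step was recorded.

```python
import matplotlib.pyplot as plt
import pandas as pd

pending = []  # Deferred plot calls; each one draws onto the Axes it is given

df = pd.DataFrame({"a": range(100)})  # Arbitrary DataFrame
# d=df binds the current DataFrame, so later reassignments of df don't leak in
pending.append(lambda ax, d=df: d.plot(ax=ax, title="0 to 100"))

df = df * -1  # Some transformation
pending.append(lambda ax, d=df: d.plot(ax=ax, title="0 to -100"))

df = df * 2  # Some other transformation
pending.append(lambda ax, d=df: d.plot(ax=ax, title="0 to -200"))

# Only now is the count known, so the grid and figure size follow from it
n = len(pending)
fig, axes = plt.subplots(n, figsize=(20, 3 * n), squeeze=False)
for draw, ax in zip(pending, axes.flat):
    draw(ax)

fig.tight_layout()
```

Adding a fourth plot here means one more pending.append(...) line and nothing else, but every plot call has to be wrapped in a lambda, which is the ceremony I'd like to avoid.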

Is there a more convenient way to do this? Specifically, is there a way to create a figure with an "expandable" number of subplots, suitable for ad-hoc/interactive contexts where the count is not known in advance?

(Note: This question may appear to be a duplicate of either this question or this one, but the accepted answers to both questions contain exactly the problem I'm trying to solve — that the nrows= parameter of plt.subplots() must be declared before adding subplots.)

Upvotes: 3

Views: 1723

Answers (3)

Stef

Reputation: 30579

First create an empty figure, then add subplots one at a time with add_subplot. Before each addition, move the existing subplots onto a new GridSpec that has the new geometry by updating their subplotspecs (the figure keyword of GridSpec is only needed if you're using constrained layout instead of tight layout).

import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd


def append_axes(fig, as_cols=False):
    """Append new Axes to Figure."""
    n = len(fig.axes) + 1
    nrows, ncols = (1, n) if as_cols else (n, 1)
    gs = mpl.gridspec.GridSpec(nrows, ncols, figure=fig)
    for i,ax in enumerate(fig.axes):
        ax.set_subplotspec(mpl.gridspec.SubplotSpec(gs, i))
    return fig.add_subplot(nrows, ncols, n)


fig = plt.figure(layout='tight')

df = pd.DataFrame({"a": range(100)}) # Arbitrary DataFrame
df.plot(ax=append_axes(fig), title="0 to 100")

df = df * -1 # Some transformation
df.plot(ax=append_axes(fig), title="0 to -100")

df = df * 2 # Some other transformation
df.plot(ax=append_axes(fig), title="0 to -200")

Example for adding the new subplots as columns (and using constrained layout for a change):

fig = plt.figure(layout='constrained')

df = pd.DataFrame({"a": range(100)}) # Arbitrary DataFrame
df.plot(ax=append_axes(fig, True), title="0 to 100")

df = df + 10 # Some transformation
df.plot(ax=append_axes(fig, True), title="10 to 110")
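
The same idea can also take over the figure-height arithmetic from the question. The following is a rough sketch under a few assumptions: append_row and row_height are names invented here, only the row-wise case is handled, and every Axes in the figure is assumed to be a regular subplot (no colorbars or twin axes).

```python
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.gridspec import GridSpec


def append_row(fig, row_height=3.0):
    """Append an Axes as a new bottom row and grow the figure to fit."""
    n = len(fig.axes) + 1
    gs = GridSpec(n, 1, figure=fig)
    # Move the existing subplots onto the taller grid, keeping their order
    for i, ax in enumerate(fig.axes):
        ax.set_subplotspec(gs[i, 0])
    # Keep the current width; make the height proportional to the row count
    fig.set_size_inches(fig.get_size_inches()[0], row_height * n)
    return fig.add_subplot(gs[n - 1, 0])


fig = plt.figure(figsize=(20, 3))

df = pd.DataFrame({"a": range(100)})  # Arbitrary DataFrame
df.plot(ax=append_row(fig), title="0 to 100")

df = df * -1  # Some transformation
df.plot(ax=append_row(fig), title="0 to -100")

df = df * 2  # Some other transformation
df.plot(ax=append_row(fig), title="0 to -200")

fig.tight_layout()
```

With three rows at the default row_height the figure ends up 20 by 9 inches, the same size as in the question's plt.subplots(3, figsize=(20, 9)) call.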

Upvotes: 4

ImportanceOfBeingErnest

Reputation: 339230

You can create an object that stores the data and only creates the figure once you tell it to do so.

import pandas as pd
import matplotlib.pyplot as plt

class AxesStacker():
    def __init__(self):
        self.data = []
        self.titles = []

    def append(self, data, title=""):
        self.data.append(data)
        self.titles.append(title)

    def create(self):
        nrows = len(self.data)
        # squeeze=False keeps axs a 2-D array, so .flat also works for a single plot
        self.fig, self.axs = plt.subplots(nrows=nrows, squeeze=False)
        for d, t, ax in zip(self.data, self.titles, self.axs.flat):
            d.plot(ax=ax, title=t)



stacker = AxesStacker()

df = pd.DataFrame({"a": range(100)})  # Some arbitrary DataFrame
stacker.append(df, title="0 to 100")

df = df * -1  # Some transformation
stacker.append(df, title="0 to -100")

df = df * 2  # Some other transformation
stacker.append(df, title="0 to -200")

stacker.create()
plt.show()
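
If you also want the figure height to follow the number of plots, as in the question's figsize=(20, 3 * n_subplots), the same pattern can carry the sizing. This is just a sketch: SizedStacker, width and row_height are hypothetical names, and sharex=True is an assumption that all frames share the same index.

```python
import matplotlib.pyplot as plt
import pandas as pd


class SizedStacker:
    """Collects DataFrames, then draws them as equally tall, x-aligned rows."""

    def __init__(self, width=20, row_height=3):
        self.width = width            # Figure width in inches
        self.row_height = row_height  # Height per subplot in inches
        self.items = []               # (data, title) pairs in plot order

    def append(self, data, title=""):
        self.items.append((data, title))

    def create(self):
        n = len(self.items)  # Needs at least one appended item
        # squeeze=False always returns a 2-D array, so a single plot works too
        fig, axs = plt.subplots(
            nrows=n, sharex=True, squeeze=False,
            figsize=(self.width, self.row_height * n),
        )
        for (data, title), ax in zip(self.items, axs.flat):
            data.plot(ax=ax, title=title)
        fig.tight_layout()
        return fig, axs


stacker = SizedStacker()

df = pd.DataFrame({"a": range(100)})  # Some arbitrary DataFrame
stacker.append(df, title="0 to 100")

df = df * -1  # Some transformation
stacker.append(df, title="0 to -100")

fig, axs = stacker.create()
```

Outside a notebook, call plt.show() after create() to display the figure.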

Upvotes: 1

Chris Adams

Reputation: 18647

IIUC you need some container for your transformations to achieve this - a list for example. Something like:

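import pandas as pd
import matplotlib.pyplot as plt

df = pd.DataFrame({"a": range(100)})  # Some arbitrary DataFrame

arbitrary_trx = [
    lambda x: x,         # No transformation
    lambda x: x * -1,    # Arbitrary transformation
    lambda x: x * 2]     # Arbitrary transformation

fig, axes = plt.subplots(nrows=len(arbitrary_trx))

for ax, f in zip(axes, arbitrary_trx):
    df = df.apply(f)     # Each transformation builds on the previous result
    df.plot(ax=ax)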
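
A possible extension of the same idea, sketched here as an assumption rather than a tested recipe, keeps a title next to each transformation and derives the figure size from the list length. The steps name and the (title, function) pairing are made up for this example.

```python
import matplotlib.pyplot as plt
import pandas as pd

# (title, transformation) pairs; each function receives the previous result
steps = [
    ("0 to 100", lambda d: d),        # No transformation
    ("0 to -100", lambda d: d * -1),  # Arbitrary transformation
    ("0 to -200", lambda d: d * 2),   # Arbitrary transformation
]

df = pd.DataFrame({"a": range(100)})  # Some arbitrary DataFrame

n = len(steps)
# squeeze=False keeps axes 2-D, so the loop also works for a single step
fig, axes = plt.subplots(nrows=n, figsize=(20, 3 * n), squeeze=False)

for ax, (title, transform) in zip(axes.flat, steps):
    df = transform(df)
    df.plot(ax=ax, title=title)

fig.tight_layout()
```

Adding or removing a plot then only means editing the steps list; the subplot count and figure height follow from len(steps).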

Upvotes: 1
