whytheq

Reputation: 35567

How does the indexing of subplots work

I have the following:

import matplotlib.pyplot as plt

fig = plt.figure()

for i in range(10):
    ax = fig.add_subplot(551 + i)
    ax.plot([1,2,3,4,5], [10,5,10,5,10], 'r-')

I was imagining that the 55 means it creates a grid 5 subplots wide and 5 subplots deep, so it can hold 25 subplots?

The for loop only iterates 10 times, so I thought (obviously wrongly) that 25 possible plots would accommodate those iterations, but I get the following:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-118-5775a5ea6c46> in <module>()
     10 
     11 for i in range(10):
---> 12     ax = fig.add_subplot(551 + i)
     13     ax.plot([1,2,3,4,5], [10,5,10,5,10], 'r-')
     14 

/home/blah/anaconda/lib/python2.7/site-packages/matplotlib/figure.pyc in add_subplot(self, *args, **kwargs)
   1003                     self._axstack.remove(ax)
   1004 
-> 1005             a = subplot_class_factory(projection_class)(self, *args, **kwargs)
   1006 
   1007         self._axstack.add(key, a)

/home/blah/anaconda/lib/python2.7/site-packages/matplotlib/axes/_subplots.pyc in __init__(self, fig, *args, **kwargs)
     62                     raise ValueError(
     63                         "num must be 1 <= num <= {maxn}, not {num}".format(
---> 64                             maxn=rows*cols, num=num))
     65                 self._subplotspec = GridSpec(rows, cols)[int(num) - 1]
     66                 # num - 1 for converting from MATLAB to python indexing

ValueError: num must be 1 <= num <= 30, not 0

Upvotes: 7

Views: 25991

Answers (2)

TheBlackCat

Reputation: 10308

Although tom answered your question, in this sort of situation you should be using fig, axs = plt.subplots(n, m). This creates a new figure with n rows and m columns of subplots. fig is the figure created, and axs is a 2D numpy array (when both n and m are greater than 1) where each element is the subplot at the corresponding location in the figure. So the top-right element of axs, axs[0, -1], is the top-right subplot in the figure. You can access the subplots through normal indexing, or loop over them.

So in your case you can do

import matplotlib.pyplot as plt

# axs is a 5x5 numpy array of axes objects
fig, axs = plt.subplots(5, 5)

# "ravel" flattens the numpy array without making a copy
for ax in axs.ravel():
    ax.plot([1,2,3,4,5], [10,5,10,5,10], 'r-')
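A small sketch to make the indexing concrete. It assumes a headless run (the Agg backend), and since only ten plots are wanted it swaps the 5x5 grid for a 2x5 one so no empty axes are left over; that grid size is a choice made for this sketch, not part of the answer above. It checks that axs[0, -1] really sits above and to the right of axs[-1, 0].

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt

# Ten plots fit exactly in 2 rows x 5 columns (grid size chosen for this sketch)
fig, axs = plt.subplots(2, 5)
print(axs.shape)  # (2, 5)

# axs[row, col]: row 0 is the top row, column -1 is the rightmost column
top_right = axs[0, -1]
bottom_left = axs[-1, 0]

# Figure-relative bounding boxes confirm where each subplot sits
tr, bl = top_right.get_position(), bottom_left.get_position()
print(tr.x0 > bl.x0, tr.y0 > bl.y0)  # True True

# ravel() visits the axes row by row, left to right
for ax in axs.ravel():
    ax.plot([1, 2, 3, 4, 5], [10, 5, 10, 5, 10], 'r-')
```

With a 2x5 grid every Axes gets a line, whereas the 5x5 version draws the same curve into all 25 subplots.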

Upvotes: 6

tmdavison

Reputation: 69136

In the convenience shorthand notation, the 55 does mean there are 5 rows and 5 columns. However, the shorthand notation only works for single-digit integers (i.e. for nrows, ncols and plot_number all less than 10).
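Because the shorthand is read digit by digit, it explains the traceback in the question. The sketch below uses a hypothetical helper, decode_three_digit (not a matplotlib function), that splits the code into hundreds, tens and units. On the last loop iteration 551 + 9 is 560, which decodes as 5 rows, 6 columns and plot number 0, matching "num must be 1 <= num <= 30, not 0".

```python
def decode_three_digit(code):
    """Hypothetical helper: split a 3-digit subplot code into (nrows, ncols, plot_number)."""
    # Unpacking fails with ValueError if code does not have exactly three digits
    nrows, ncols, plot_number = (int(d) for d in str(code))
    return nrows, ncols, plot_number


# Replay the question's loop: 551 + i for i in 0..9
for i in range(10):
    code = 551 + i
    nrows, ncols, num = decode_three_digit(code)
    valid = 1 <= num <= nrows * ncols
    print(code, (nrows, ncols, num), "ok" if valid else "invalid")
```

The first nine codes decode as intended; only 560 breaks. Plot number 10 on a 5x5 grid cannot be written in this form at all, since it would need four digits (5510).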

You can expand it to full notation (i.e. use commas: add_subplot(nrows, ncols, plot_number)) and then all will work fine for you:

import matplotlib.pyplot as plt

fig = plt.figure()
for i in range(10):
    ax = fig.add_subplot(5, 5, 1 + i)
    ax.plot([1,2,3,4,5], [10,5,10,5,10], 'r-')

From the docs for plt.subplot (which uses the same args as fig.add_subplot):

Typical call signature:

subplot(nrows, ncols, plot_number) 

Where nrows and ncols are used to notionally split the figure into nrows * ncols sub-axes, and plot_number is used to identify the particular subplot that this function is to create within the notional grid. plot_number starts at 1, increments across rows first and has a maximum of nrows * ncols.

In the case when nrows, ncols and plot_number are all less than 10, a convenience exists, such that a 3-digit number can be given instead, where the hundreds represent nrows, the tens represent ncols and the units represent plot_number.
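The "increments across rows first" rule amounts to a single integer division. This sketch uses a hypothetical helper, plot_number_to_cell (not part of matplotlib), that turns a 1-based plot_number into a zero-based (row, col) and applies the same range check as the error in the question. For the corrected loop, plot numbers 1 to 10 on a 5x5 grid land in the first two rows.

```python
def plot_number_to_cell(nrows, ncols, plot_number):
    """Hypothetical helper: map a 1-based plot_number to a zero-based (row, col)."""
    if not 1 <= plot_number <= nrows * ncols:
        raise ValueError(
            f"num must be 1 <= num <= {nrows * ncols}, not {plot_number}")
    # Row-major order: fill a row left to right, then move down one row
    return divmod(plot_number - 1, ncols)


# The corrected loop uses plot numbers 1..10 on a 5x5 grid
cells = [plot_number_to_cell(5, 5, 1 + i) for i in range(10)]
print(cells[0], cells[4], cells[5], cells[9])  # (0, 0) (0, 4) (1, 0) (1, 4)
```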

Upvotes: 9
