Jason

Reputation: 181

Index error within a subplot command in matplotlib

I am not sure whether this is a small bug in the source code, but I cannot scatter plot two arrays, each of shape (1018,), on a subplot. Both arrays are results from an OLS regression. I never had problems scatter plotting these variables until I used the subplot commands below. Here is my code, followed by the traceback:

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,5))

axes[0,0].scatter(Blodgett_wue_obs,Blodgett_wue_results.fittedvalues,color ='blue')
axes[0,0].plot(Blodgett_wue_obs,Blodgett_wue_obs,'r')


File "<ipython-input-311-527571e09d59>", line 1, in <module>
runfile('/Users/JasonDucker/Documents/forJason/Paper_Plots.py', wdir='/Users/JasonDucker/Documents/forJason')

File "/Users/JasonDucker/anaconda/lib/python2.7/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 685, in runfile
execfile(filename, namespace)

File "/Users/JasonDucker/anaconda/lib/python2.7/site-packages/spyderlib/widgets/externalshell/sitecustomize.py", line 78, in execfile
builtins.execfile(filename, *where)

File "/Users/JasonDucker/Documents/forJason/Paper_Plots.py", line 362, in <module>
axes[0,0].scatter(Blodgett_wue_obs,Blodgett_wue_results.fittedvalues,color ='blue')

IndexError: too many indices for array

Any thoughts on this issue would be much appreciated!

Upvotes: 0

Views: 3267

Answers (1)

xnx

Reputation: 25518

If you specify one row and three columns, axes is a one-dimensional array of shape (3,), not (1,3), so axes[0,0] has one index too many. Index the individual axes as axes[0], axes[1], axes[2] instead.

In [8]: fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,5))
In [9]: axes.shape
Out[9]: (3,)
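
Applied to the code in the question, the fix might look like the sketch below. The arrays obs and fitted are made-up stand-ins for Blodgett_wue_obs and Blodgett_wue_results.fittedvalues, and the Agg backend is selected only so that the snippet runs without a display; neither comes from the original post.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs headless
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the observed and fitted values in the question
rng = np.random.default_rng(0)
obs = rng.normal(size=1018)
fitted = obs + rng.normal(scale=0.3, size=1018)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12, 5))

# With one row, axes is one-dimensional, so a single index selects a panel
axes[0].scatter(obs, fitted, color='blue')
axes[0].plot(obs, obs, 'r')  # 1:1 reference line
```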

The best way to get plt.subplots to always return a two-dimensional array is to set squeeze=False:

In [8]: fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,5), squeeze=False)
In [9]: axes.shape
Out[9]: (1, 3)
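
As a quick check of that behaviour, the following sketch records the shape returned for a few grid sizes with squeeze=False. The grid sizes, the shapes dictionary and the Agg backend are choices made for this illustration only.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs headless
import matplotlib.pyplot as plt

shapes = {}
for nrows, ncols in [(1, 1), (1, 3), (3, 1), (2, 3)]:
    fig, axes = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    shapes[(nrows, ncols)] = axes.shape
    # Two-index access works for every grid size, including 1x1
    axes[0, 0].set_title("first panel")
    plt.close(fig)  # release the figure once its shape is recorded

print(shapes)
```

With squeeze=False the same axes[row, col] indexing can be used even if the number of rows or columns changes later.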

Alternatively (if you already have axes as a one-dimensional array), you could use:

axes = np.atleast_2d(axes)

For example,

In [8]: fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(12,5))
In [9]: axes.shape
Out[9]: (3,)
In [10]: axes = np.atleast_2d(axes)
In [11]: axes.shape
Out[11]: (1, 3)

In [12]: axes[0,0]
Out[12]: <matplotlib.axes._subplots.AxesSubplot at 0x1149499d0>

Upvotes: 5
