Gamp

Reputation: 319

Python subplot used to show one figure

When I use the following subplots code to plot only one figure, I get this error:

AttributeError: 'AxesSubplot' object has no attribute 'flat'

fig, ax = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
for i, ax in enumerate(ax.flat):
    ax.plot(X, Y, color='k')

How can I solve this issue if I want to set the number of subplots arbitrarily? And how can I easily understand ax.flat?

Upvotes: 0

Views: 5184

Answers (4)

ImportanceOfBeingErnest

Reputation: 339560

There is precisely one case where the code

fig, ax = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
for i, ax in enumerate(ax.flat):
    ax.plot(X, Y, color='k')

would not work as expected: nrows = ncols = 1. With a single row and a single column, ax is a single subplot (one Axes object), not an array of several subplots.

To circumvent this problem, and to be able to use the same code without knowing nrows and ncols in advance, use the squeeze=False option. This ensures that ax is always a 2D array and hence always has a .flat attribute. For readability, don't give the axes array the same name as the individual axes.

fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False, figsize=figsize)
for i, ax in enumerate(axs.flat):
    ax.plot(X, Y, color='k')
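To see what squeeze=False does, the sketch below creates a few grid sizes and records the shape of the returned array. It assumes the non-interactive Agg backend so it runs without a display, and uses placeholder data instead of the question's X and Y.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs headless
import matplotlib.pyplot as plt

# With squeeze=False, subplots() returns a 2D array of shape (nrows, ncols)
# for every grid size, including 1x1, so .flat is always available.
shapes = {}
for nrows, ncols in [(1, 1), (1, 3), (2, 1), (2, 3)]:
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    shapes[(nrows, ncols)] = axs.shape
    for ax in axs.flat:
        ax.plot([0, 1], [0, 1], color='k')  # placeholder data
    plt.close(fig)  # free the figure; only the shapes matter here

print(shapes)
# {(1, 1): (1, 1), (1, 3): (1, 3), (2, 1): (2, 1), (2, 3): (2, 3)}
```

Because the array is always 2D, axs[r, c] indexing also works the same way for every grid size.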

Upvotes: 6

josemz

Reputation: 1312

flat is an attribute of NumPy arrays that returns an iterator (a numpy.flatiter object). For example, if you have a 2d array like this:

import numpy as np
arr2d = np.arange(4).reshape(2, 2)
arr2d
# array([[0, 1],
#        [2, 3]])

the flat attribute provides a convenient way to iterate through this array as if it were a 1d array:

for value in arr2d.flat:
    print(value)
# 0
# 1
# 2
# 3

You can also flatten the array with the flatten method:

arr2d.flatten()
# array([0, 1, 2, 3])
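The difference between flat, flatten() and ravel() can be sketched on the same small array: flat is an iterator that also accepts a flat index, flatten() always makes a copy, and ravel() returns a view when it can.

```python
import numpy as np

arr2d = np.arange(4).reshape(2, 2)

# .flat is a numpy.flatiter: it walks the elements in row-major order
# and also supports indexing with a 1D (flat) index.
it = arr2d.flat
print(type(it).__name__)   # flatiter
print(it[3])               # 3

# flatten() always returns a new 1D copy, so writing to it leaves arr2d alone.
copy_1d = arr2d.flatten()
copy_1d[0] = 100
print(arr2d[0, 0])         # 0

# ravel() returns a view when possible (arr2d is contiguous here),
# so writing to it changes arr2d.
view_1d = arr2d.ravel()
view_1d[0] = 100
print(arr2d[0, 0])         # 100
```

For iterating over an array of axes, all three give the same order; the copy-versus-view distinction only matters if you modify the result.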

So going back to your question, when you specify:

  • ncols to 1 and nrows to a value greater than 1, or the other way around, you get the axes in a 1d numpy array, so iterating over the flat attribute yields the axes in the same order as iterating over the array itself.
  • both ncols and nrows to values greater than 1, you get the axes in a 2d array, and the flat attribute iterates over it as if it had been flattened, row by row.
  • both ncols and nrows to 1, you get a single Axes object, which doesn't have a flat attribute.
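The three cases above can be checked directly. This sketch assumes the Agg backend; subplot_return is a hypothetical helper introduced only for illustration. The class name printed for the 1x1 case depends on the matplotlib version (older versions report AxesSubplot, newer ones Axes), so no exact output is shown for it.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

def subplot_return(nrows, ncols):
    """Hypothetical helper: return whatever plt.subplots() gives as its second value."""
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
    plt.close(fig)
    return ax

for grid in [(3, 1), (1, 3), (2, 3), (1, 1)]:
    ax = subplot_return(*grid)
    if isinstance(ax, np.ndarray):
        # 1D for a single row or column, 2D otherwise
        print(grid, "ndarray with shape", ax.shape)
    else:
        # A single Axes object, which has no .flat attribute
        print(grid, type(ax).__name__, "has .flat:", hasattr(ax, "flat"))
```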

So one possible solution is to turn your ax object into a NumPy array every time:

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)

ax = np.array(ax)
for i, axi in enumerate(ax.flat):
    axi.plot(X, Y, color='k')
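The np.array trick above can be checked on several grid sizes, including 1x1. This is a sketch assuming the Agg backend, with placeholder X and Y data standing in for the question's. For a single Axes, np.array wraps it in a 0-d object array whose .flat yields that one Axes; an existing array of Axes is simply copied.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

X = [0, 1, 2]  # placeholder data standing in for the question's X and Y
Y = [0, 1, 4]

counts = {}
for nrows, ncols in [(1, 1), (1, 2), (2, 2)]:
    fig, ax = plt.subplots(nrows=nrows, ncols=ncols)
    ax = np.array(ax)  # single Axes -> 0-d object array; arrays pass through
    n = 0
    for i, axi in enumerate(ax.flat):
        axi.plot(X, Y, color='k')
        n += 1
    counts[(nrows, ncols)] = n  # how many Axes the loop visited
    plt.close(fig)

print(counts)  # {(1, 1): 1, (1, 2): 2, (2, 2): 4}
```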

Upvotes: 0

Sheldore

Reputation: 39072

You can use either ax.ravel() or ax.flatten(); both return the axes as a 1D array. Below is a simple example:

import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(8, 6))

# for i, ax in enumerate(axs.flatten()):
for i, ax in enumerate(axs.ravel()):
    ax.plot([1,2,3], color='k')
plt.show()

[Figure: output plot of the example above]
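A short check that ravel() and flatten() visit the axes in the same row-major order, and a reminder that neither exists on the bare Axes returned for a 1x1 grid unless squeeze=False is passed. This sketch assumes the Agg backend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

fig, axs = plt.subplots(nrows=2, ncols=3)
# Both return a 1D array of the same Axes objects, row by row:
# axs[0, 0], axs[0, 1], axs[0, 2], axs[1, 0], ...
row_major = [axs[r, c] for r in range(2) for c in range(3)]
same_order = list(axs.ravel()) == list(axs.flatten()) == row_major
print(same_order)  # True
plt.close(fig)

# For a 1x1 grid, plt.subplots() returns a bare Axes with neither method;
# squeeze=False keeps a (1, 1) array so ravel() still works.
fig, axs = plt.subplots(nrows=1, ncols=1, squeeze=False)
print(len(axs.ravel()))  # 1
plt.close(fig)
```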

Upvotes: 0

swatchai

Reputation: 18812

When you create a grid of subplots with multiple rows and columns using fig, ax = plt.subplots(), it returns the figure and a NumPy array of axes, ax. When both nrows and ncols are greater than 1, ax is 2-dimensional with shape (rows, cols). That is why you need to flatten it to 1 dimension when you iterate over it in a single loop. To access a specific axes, use its row and column indices: for example, ax[r][c] is the axes in the (r+1)th row and (c+1)th column, because the indices are zero-based. The working code below demonstrates how to do it.

import matplotlib.pyplot as plt
import numpy as np

nrows,ncols = 3,2
figsize = [5,9]
X = np.random.rand(6)
Y = np.random.rand(6)

fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
for i, axi in enumerate(ax.flat):
    axi.plot(X, Y, color='k')
    rowid = i // ncols
    colid = i % ncols
    axi.set_title("row:"+str(rowid)+",col:"+str(colid))

# You can access the axes by row_id, col_id.
# Now let's plot on ax[row_id][col_id] of your choice
ax[0][1].plot(Y,X,color='red')    # plot 2nd line in red
ax[2][0].plot(Y,X,color='green')  # plot 2nd line in green

plt.tight_layout()
plt.show()

The output plot:

[Figure: output plot of the code above]
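The row/column arithmetic in the loop above (i // ncols and i % ncols) can be verified on its own. This sketch assumes the Agg backend and uses divmod, which computes the same quotient and remainder in one call, to confirm that ax.flat[i] is the same object as ax[row][col].

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend
import matplotlib.pyplot as plt

nrows, ncols = 3, 2
fig, ax = plt.subplots(nrows=nrows, ncols=ncols)

# .flat walks the grid row by row, so flat index i maps to
# row i // ncols and column i % ncols.
mapping = []
for i, axi in enumerate(ax.flat):
    rowid, colid = divmod(i, ncols)
    assert axi is ax[rowid][colid]  # same Axes object, reached two ways
    mapping.append((i, rowid, colid))
plt.close(fig)

print(mapping)
# [(0, 0, 0), (1, 0, 1), (2, 1, 0), (3, 1, 1), (4, 2, 0), (5, 2, 1)]
```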

Upvotes: 0
