Reputation: 319
When I use plt.subplots with the code below to create only a single subplot, it raises an error:
AttributeError: 'AxesSubplot' object has no attribute 'flat'
fig, ax = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
for i, ax in enumerate(ax.flat):
    ax.plot(X, Y, color='k')
How can I solve this if I want to set the number of subplots arbitrarily? And how can I easily understand ax.flat?
Upvotes: 0
Views: 5184
Reputation: 339560
There is precisely one case where the code
fig, ax = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
for i, ax in enumerate(ax.flat):
    ax.plot(X, Y, color='k')
would not work as expected: nrows = ncols = 1. With a single row and a single column, ax is a single subplot, not an array of several subplots.
To circumvent this problem, and to be able to use the same code without knowing nrows and ncols in advance, use the squeeze=False option. This ensures that ax is always a 2D array and hence has a .flat attribute. For better readability, don't give the axes array the same name as the individual axes.
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False, figsize=figsize)
for i, ax in enumerate(axs.flat):
    ax.plot(X, Y, color='k')
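As a quick check of the behaviour described above, the sketch below calls plt.subplots with squeeze=False for a few grid sizes and prints the shape of the returned array. The helper name axes_grid_shape and the use of the non-interactive Agg backend are illustrative assumptions, not part of the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt


def axes_grid_shape(nrows, ncols):
    """Return the shape of the axes array created with squeeze=False (hypothetical helper)."""
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False)
    shape = axs.shape
    plt.close(fig)  # only the shape is needed, so release the figure
    return shape


# Even the 1x1 case yields a 2D array, so .flat is always available.
for nrows, ncols in [(1, 1), (1, 3), (2, 3)]:
    print((nrows, ncols), "->", axes_grid_shape(nrows, ncols))
```

With squeeze=False the array always has two dimensions, so the same loop over axs.flat should work for any combination of nrows and ncols.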
Upvotes: 6
Reputation: 1312
flat is an attribute of numpy arrays which returns an iterator. For example, if you have a 2d array like this:
import numpy as np
arr2d = np.arange(4).reshape(2, 2)
arr2d
# array([[0, 1],
#        [2, 3]])
the flat attribute is provided as a convenient way to iterate through this array as if it were a 1d array:
for value in arr2d.flat:
    print(value)
# 0
# 1
# 2
# 3
You can also flatten the array with the flatten method, which returns a copy as a new 1d array:
arr2d.flatten()
# array([0, 1, 2, 3])
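To make the difference between the two concrete, here is a minimal sketch (the variable name copy1d is only illustrative). It shows that flat iterates over the original array, while flatten produces an independent copy:

```python
import numpy as np

arr2d = np.arange(4).reshape(2, 2)

# .flat is an iterator over the original array, so assigning through it
# modifies arr2d in place.
arr2d.flat[0] = 10

# .flatten() returns a new 1d array; changing the copy leaves arr2d untouched.
copy1d = arr2d.flatten()
copy1d[1] = 99

print(arr2d)
# [[10  1]
#  [ 2  3]]
print(copy1d)
# [10 99  2  3]
```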
So going back to your question, when you set:
- ncols to 1 and nrows to a value greater than 1, or the other way around, you get the axes in a 1d numpy array, in which case the flat attribute simply iterates over that array.
- ncols and nrows to values greater than 1, you get the axes in a 2d array, in which case the flat attribute iterates over the axes as if the array were flattened.
- ncols and nrows to 1, you get a single axes object, which doesn't have a flat attribute.
So a possible solution would be to turn your ax object into a numpy array every time:
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
ax = np.array(ax)
for i, axi in enumerate(ax.flat):
    axi.plot(...)
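For the single-subplot case specifically, the following sketch illustrates what the wrapping does. It assumes the Agg backend so it can run without a display, and the counter plotted exists only for illustration:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here for a headless run
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(nrows=1, ncols=1)  # ax is a single axes object here
axs = np.array(ax)                        # wrap it in a 0-dimensional object array

print(axs.ndim)  # 0

# .flat works on a 0-dimensional array too and yields the single axes once.
plotted = 0
for axi in axs.flat:
    axi.plot([1, 2, 3], color='k')
    plotted += 1

print(plotted)  # 1
plt.close(fig)
```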
Upvotes: 0
Reputation: 39072
You can use either axs.ravel() or axs.flatten() when the axes are returned as an array. Below is a simple example:
import matplotlib.pyplot as plt
fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(8, 6))
# for i, ax in enumerate(axs.flatten()):
for i, ax in enumerate(axs.ravel()):
    ax.plot([1,2,3], color='k')
plt.show()
Upvotes: 0
Reputation: 18812
When you create a set of subplots with multiple rows and columns using fig, ax = plt.subplots(), it returns a figure fig and a numpy array of axes ax. The array is 2-dimensional, with shape (rows, cols). That is why you need to flatten it to 1 dimension when you iterate over it. To access a specific axes, you need the row and column indices; for example, ax[r][c] is the axes in the (r+1)th row and (c+1)th column, because the indices are zero-based. The working code below demonstrates how to do it.
import matplotlib.pyplot as plt
import numpy as np
nrows,ncols = 3,2
figsize = [5,9]
X = np.random.rand(6)
Y = np.random.rand(6)
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=figsize)
for i, axi in enumerate(ax.flat):
    axi.plot(X, Y, color='k')
    rowid = i // ncols
    colid = i % ncols
    axi.set_title("row:"+str(rowid)+",col:"+str(colid))
# You can access the axes by row_id, col_id.
# Now let's plot on ax[row_id][col_id] of your choice
ax[0][1].plot(Y,X,color='red') # plot 2nd line in red
ax[2][0].plot(Y,X,color='green') # plot 2nd line in green
plt.tight_layout()
plt.show()
The output plot is a 3x2 grid of subplots, each titled with its row and column index.
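The row and column bookkeeping in the loop above can also be left to numpy. This is a sketch of an alternative, not the approach used in this answer: np.ndenumerate yields each (row, col) index together with the corresponding axes. The titles list and the Agg backend are illustrative assumptions.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, assumed here for a headless run
import matplotlib.pyplot as plt
import numpy as np

nrows, ncols = 3, 2
# squeeze=False keeps axs 2-dimensional for every grid size
fig, axs = plt.subplots(nrows=nrows, ncols=ncols, squeeze=False, figsize=(5, 9))

titles = []
for (row, col), axi in np.ndenumerate(axs):
    # row and col come straight from the array index, no integer division needed
    axi.set_title(f"row:{row},col:{col}")
    titles.append(axi.get_title())

print(titles)
plt.close(fig)
```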
Upvotes: 0