I have variables x and y:
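import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

def function(a, b):
    # Use new local names: assigning to x inside the function would make x local
    # and raise UnboundLocalError when the global x is read on the right-hand side
    xs = x[(x > a) & (x < b)]
    ys = y[(y < a) & (y > b)]
    # perform some fitting routine using curve_fit on xs and ys
    fig = plt.figure()
    ax = fig.add_subplot(111)
    phist, xedge, yedge, img = ax.hist2d(xs, ys, bins=20, norm=LogNorm())
    im = ax.imshow(phist, cmap=plt.cm.jet, norm=LogNorm(), aspect='auto')
    fig.colorbar(im, ax=ax)
    fig.show()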
All works fine. But I have 6 pairs of different input parameters a and b. I would like to call function(a, b) in a loop and plot the six different x and y (corresponding to the 6 input pairs) as 6 subplots, like we do with:
ax1 = fig.add_subplot(231) # x vs y for a1,b1
ax2 = fig.add_subplot(232) # x vs y for a2,b2
....
ax6 = fig.add_subplot(236) # x vs y for a6,b6
I would like an idea of how to proceed to get the final figure with all six subplots.
I know it can be done manually by using different variables, like x1 and y1 for the first input pair a and b, and so on for the other five pairs (x2, y2, ..., x6, y6). But that would make the code very lengthy and confusing.
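One way to avoid separate x1 ... x6 variables is to keep one result per pair in a list. This is a minimal sketch that only covers the x filtering step from function(a, b); the sample data and the six (a, b) pairs are hypothetical placeholders.

```python
import numpy as np

# Hypothetical sample data standing in for the question's x
rng = np.random.default_rng(0)
x = rng.uniform(0, 10, size=1000)

# Six hypothetical (a, b) pairs; replace with the real ones
pairs = [(0, 2), (1, 4), (2, 6), (3, 7), (4, 9), (5, 10)]

# One filtered array per pair, instead of x1, ..., x6
x_filtered = [x[(x > a) & (x < b)] for a, b in pairs]

for (a, b), xs in zip(pairs, x_filtered):
    print(f"a={a}, b={b}: {xs.size} points")
```

The same list (or a dict keyed by the pair) can then be zipped with the subplot axes when plotting.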
Use plt.subplots instead of plt.subplot (note the "s" at the end). fig, axs = plt.subplots(2, 3) creates a figure with a 2x3 grid of subplots, where fig is the figure and axs is a 2x3 numpy array in which each element is the Axes object at the same position in the figure (so axs[1, 2] is the bottom-right axis).
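A small sketch to see the layout of the axs array: it labels every subplot with its index and checks that axs[1, 2] sits at the bottom right. The Agg backend is an assumption so the snippet runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so this runs without a display
import matplotlib.pyplot as plt

fig, axs = plt.subplots(2, 3)
print(axs.shape)  # (2, 3)

# Label each subplot with its (row, column) index to see the layout
for i in range(2):
    for j in range(3):
        axs[i, j].set_title(f"axs[{i}, {j}]")

# Bottom-right subplot: furthest right and lowest of all the axes
bottom_right = axs[1, 2].get_position()
top_left = axs[0, 0].get_position()
print(bottom_right.x0 > top_left.x0, bottom_right.y0 < top_left.y0)
```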
You can then either use a pair of loops to loop over each row then each axis in that row:
fig, axs = plt.subplots(2, 3)
for i, row in enumerate(axs):
    for j, ax in enumerate(row):
        # foo stands for your data: foo[i, j] is the image for subplot (i, j)
        ax.imshow(foo[i, j])
fig.show()
Or you can use ravel to flatten both the axes array and whatever array you get the data from:
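fig, axs = plt.subplots(2, 3)
# assumes foo is a 2x3 array whose elements are the images to plot;
# for a 4D array of shape (2, 3, H, W), use foo.reshape(6, *foo.shape[2:]) instead
foor = foo.ravel()
for i, ax in enumerate(axs.ravel()):
    ax.imshow(foor[i])
fig.show()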
Note that ravel returns a view rather than a copy whenever possible, so this usually won't take any additional memory.
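The view-versus-copy behaviour can be checked with numpy's np.shares_memory. In this sketch with made-up arrays, ravel on a C-contiguous array returns a view, while on a non-contiguous one (a transpose) it has to return a copy.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

flat = a.ravel()                    # C-contiguous input: ravel returns a view
print(np.shares_memory(a, flat))    # True

flat_t = a.T.ravel()                # transposed, not C-contiguous: ravel copies
print(np.shares_memory(a, flat_t))  # False

# Writing through the view changes the original array
flat[0] = 99
print(a[0, 0])                      # 99
```

If a guaranteed view matters, iterating with axs.flat avoids the question entirely, since it is an iterator over the original array.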
The key is using the three-parameter form of subplot, plt.subplot(nrows, ncols, index), where index starts at 1:
import matplotlib.pyplot as plt

# Build a list of pairs for a, b
ab = zip(range(6), range(6))
# iterate through them
for i, (a, b) in enumerate(ab):
    plt.subplot(2, 3, i + 1)  # subplot indices are 1-based
    # function(a, b)
    plt.plot(a, b, 'o')  # placeholder; 'o' draws a marker so the single point is visible
plt.show()
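You'll just have to move the plt.figure() call (and the add_subplot(111)) out of the function, so that the function draws on the current subplot, for example via plt.gca().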
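A variant of both answers is to pass the target axes into the function explicitly. This is a sketch under several assumptions: the data and the six (a, b) pairs are hypothetical, x and y are treated as paired samples so a single mask on x filters both (the question filters them separately, which gives arrays of different lengths), and the fitting step is left as a comment.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; use an interactive one to call plt.show()
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm

# Hypothetical sample data standing in for the question's x and y
rng = np.random.default_rng(0)
x = rng.normal(5, 2, size=5000)
y = rng.normal(5, 2, size=5000)

def function(a, b, ax):
    """Filter the data for one (a, b) pair and draw a 2D histogram on ax."""
    # Assumption: one mask on x keeps x and y paired and the same length
    mask = (x > a) & (x < b)
    xs, ys = x[mask], y[mask]
    # ... curve_fit on xs, ys would go here ...
    counts, xedges, yedges, mesh = ax.hist2d(xs, ys, bins=20, norm=LogNorm())
    ax.set_title(f"a={a}, b={b}")
    return mesh

pairs = [(0, 4), (1, 5), (2, 6), (3, 7), (4, 8), (5, 9)]  # hypothetical pairs

fig, axs = plt.subplots(2, 3, figsize=(12, 7))
meshes = []
for (a, b), ax in zip(pairs, axs.ravel()):
    mesh = function(a, b, ax)
    fig.colorbar(mesh, ax=ax)  # one colorbar per subplot
    meshes.append(mesh)
fig.tight_layout()
fig.savefig("six_pairs.png")
```

Because the function no longer creates its own figure, the same function works for a single plot (pass one Axes) or for the 2x3 grid.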