Reputation: 1604
How can I modify my code below to return multiple plots based on the subplots dimensions?
def test():
    f, ax = plt.subplots(2,2)
    x=np.array([1,2,3])
    y=np.array([1,2,3])
    for i, ax in enumerate(ax.flat):
        ax.plot(x,y)
So that I can have something like this:
ax1, ax2, ax3, ax4 = test()
Upvotes: 0
Views: 831
Reputation: 142661
You could use return list(ax.flat), but you reuse the name ax as the loop variable in the for loop to access each single plot, so you no longer have access to the original ax array. You should rename one of the variables:
import matplotlib.pyplot as plt
import numpy as np
# --- functions ---
def test():
    f, all_axs = plt.subplots(2,2)
    x = np.array([1,2,3])
    y = np.array([1,2,3])
    for i, ax in enumerate(all_axs.flat):
        ax.plot(x,y)
    return list(all_axs.flat)
# --- main ---
ax1, ax2, ax3, ax4 = test()
print(ax1, ax2, ax3, ax4)
plt.show()
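To see why the rename matters, here is a minimal sketch of the name shadowing on its own. It uses a plain NumPy array called grid as a stand-in for the array of plots (that name and the integer contents are assumptions for illustration only, not part of the original code):

```python
import numpy as np

# Hypothetical stand-in for the 2x2 array returned by plt.subplots(2, 2)
grid = np.arange(4).reshape(2, 2)

ax = grid
# ax.flat is evaluated once, then the loop variable rebinds the name ax
for i, ax in enumerate(ax.flat):
    pass

# After the loop, ax is the last single element, not the original array
print(np.ndim(ax))   # 0
print(grid.shape)    # (2, 2)
```

The array itself is untouched; only the name ax stopped pointing at it, which is why ax.flat is no longer usable after the loop.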
EDIT:
I would keep the result as a list. Then the function can take rows and cols as arguments to create more or fewer plots:
import matplotlib.pyplot as plt
import numpy as np
# --- functions ---
def test(rows=2, cols=5):
    f, all_axs = plt.subplots(rows, cols)
    x = np.array([1,2,3])
    y = np.array([1,2,3])
    for i, ax in enumerate(all_axs.flat):
        ax.plot(x,y)
    return list(all_axs.flat)
# --- main ---
all_axs = test(2, 5)
print(all_axs)
plt.show()
all_axs = test(5, 10)
print(all_axs)
plt.show()
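One thing to watch if rows and cols can both be 1: by default plt.subplots(1, 1) returns a single plot object rather than an array, so .flat would fail. A possible variation, sketched here under the assumption that you also want the figure back, passes squeeze=False so the result is always a 2D array. The name make_grid is made up for this example:

```python
import matplotlib.pyplot as plt
import numpy as np

def make_grid(rows=2, cols=2):
    # squeeze=False makes plt.subplots always return a 2D array of plots,
    # even when rows and/or cols is 1
    fig, all_axs = plt.subplots(rows, cols, squeeze=False)
    x = np.array([1, 2, 3])
    y = np.array([1, 2, 3])
    for ax in all_axs.flat:
        ax.plot(x, y)
    # Return the figure too, so the caller can save or resize it
    return fig, list(all_axs.flat)

fig, axs = make_grid(1, 1)
print(len(axs))   # 1
```

Call plt.show() afterwards as in the examples above to display the figure.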
Upvotes: 1