Roger Steinberg

How to return multiple plots from a function

How can I modify the code below so that it returns multiple plots (Axes objects), one for each subplot created by plt.subplots?

import matplotlib.pyplot as plt
import numpy as np

def test():
  f, ax = plt.subplots(2,2)
  x=np.array([1,2,3])
  y=np.array([1,2,3])

  for i, ax in enumerate(ax.flat):
    ax.plot(x,y)

I want to be able to write something like this:

ax1, ax2, ax3, ax4 = test()


Answers (1)

furas

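You could use return list(ax.flat), but your for-loop reuses the name ax for each individual subplot, so after the loop ax refers to the last single plot and you no longer have access to the original array of axes. Rename the variables so the array and the single plot have different names: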

import matplotlib.pyplot as plt
import numpy as np

# --- functions ---

def test():
    f, all_axs = plt.subplots(2,2)

    x = np.array([1,2,3])
    y = np.array([1,2,3])

    # all_axs.flat iterates over the subplots row by row
    for ax in all_axs.flat:
        ax.plot(x,y)

    return list(all_axs.flat)

# --- main ---

ax1, ax2, ax3, ax4 = test()

print(ax1, ax2, ax3, ax4)

plt.show()
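To see why the original version cannot simply end with return list(ax.flat), here is a minimal sketch that inspects what the name ax refers to once the loop has finished. It assumes the non-interactive Agg backend so it runs without a display; the function name original_loop is only illustrative.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

def original_loop():
    fig, ax = plt.subplots(2, 2)  # here ax is a 2x2 NumPy array of Axes
    x = np.array([1, 2, 3])
    y = np.array([1, 2, 3])
    # enumerate(ax.flat) is evaluated once, before the loop rebinds ax
    for i, ax in enumerate(ax.flat):
        ax.plot(x, y)
    # after the loop, ax is bound to the last single Axes, not the array
    return ax

last = original_loop()
print(isinstance(last, np.ndarray))  # False: it is one Axes object, not the grid
```

The loop itself works because the iterable is created before ax is rebound; only code after the loop sees the shadowed name.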

EDIT:

I would keep the result as a list; then the function can take rows and cols arguments to create more or fewer plots.

import matplotlib.pyplot as plt
import numpy as np

# --- functions ---

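def test(rows=2, cols=5):
    # squeeze=False makes plt.subplots always return a 2-D array of Axes,
    # so .flat also works for rows=1, cols=1 (otherwise a single Axes is returned)
    f, all_axs = plt.subplots(rows, cols, squeeze=False)

    x = np.array([1,2,3])
    y = np.array([1,2,3])

    for ax in all_axs.flat:
        ax.plot(x,y)

    return list(all_axs.flat)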

# --- main ---

all_axs = test(2, 5)
print(all_axs)

plt.show()

all_axs = test(5, 10)
print(all_axs)

plt.show()
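If you sometimes want individual names and sometimes the whole list, Python's extended unpacking lets you name the first few plots and collect the rest in a list. This is a minimal sketch under stated assumptions: make_axes is a hypothetical name for a function mirroring test(rows, cols) above, it passes squeeze=False so that a 1x1 grid still yields a list, and the Agg backend is used so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

def make_axes(rows=2, cols=5):
    # hypothetical helper mirroring test(rows, cols)
    # squeeze=False: always a 2-D array of Axes, even for a 1x1 grid
    fig, all_axs = plt.subplots(rows, cols, squeeze=False)
    x = np.array([1, 2, 3])
    y = np.array([1, 2, 3])
    for ax in all_axs.flat:
        ax.plot(x, y)
    return list(all_axs.flat)

# a single subplot still comes back as a one-element list
(only_ax,) = make_axes(1, 1)

# name the first two plots and keep the remaining ones in a list
first, second, *rest = make_axes(2, 5)
print(len(rest))  # 8
```

Plain unpacking like ax1, ax2, ax3, ax4 = ... requires the number of names to match rows * cols exactly; the starred form avoids that constraint.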
