Alex Cushley

Reputation: 95

How can I save multiple figures after they have been created in a loop?

I'm relatively new to Python and have been rewriting some of my MATLAB scripts. I have a script that creates 7 figures within a loop. After they have been made I want to save each one. I am having two problems, illustrated by the following MWE.

Problem 1 - If I save the figure in the loop it overwrites itself and I end up with only one plot (the last iteration of the for-loop), when there should be 6 plots. I can't find how to specify which figure to save so that I can do it after the loop; I only find documentation on saving the current figure, which obviously isn't working.

Problem 2 - While this isn't a real problem for me, it indicates I don't understand what is happening and I'd like to learn more and improve my understanding. Before I added the line

fig3, ax3 = plt.subplots()

Figure 2 was empty. I'm not sure why I'd need to create a third, empty figure in order for Figure 2 to show. In this MWE I don't need a third figure, and in the script I'm working on I need to create an 8th figure for Figure 7 to show. Just comment out this line to see what I mean.

Thank you in advance for your patience and help with these problems. Here is the code for the MWE:

# Test for plotting on multiple figures for multiple angles within a for-loop

#Import packages
import numpy as np 
from math import *
from astropy.table import Table 
import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d

plt.close('all')

# define the initial conditions
x = 0               # initial x position
y = 0               # initial y position
z = 0               # initial z position

v = 30 
g = -9.8 

lst = [ 20, 30, 40, 45, 50, 60] # launch angles    
alpha= np.array(lst) 

def size(arr):
    if len(arr.shape) == 1:
        return arr.shape[0], 1
    return arr.shape

[nn,mm] = size(alpha)   #https://appdividend.com/2022/02/02/python-array-length/

#create plots:
ax1 = plt.axes(projection='3d')#fig1, ax1 = plt.subplots()
fig2, ax2 = plt.subplots(1,1)
fig3, ax3 = plt.subplots()

for kk in range(nn):#= 1:nn
    theta = alpha[kk]*pi/180.0

    # reset x=y=z=0 for next iteration since they are used to initialize the tables
    x = 0               # initial x position
    y = 0               # initial y position
    z = 0               # initial z position
    t = 0               # starting at time 0
 
    x3table = [x/3]
    y3table = [y/3]
    z3table = [z/3]

    h = 0.0100 
    tf = 200
    N=ceil(tf/h)
    for i in range(N):#for i = 1:N
        t = t + h #t(i+1) = t(i) + h

        # update position
        x = v*t*cos(theta)
        y = v*t*sin(theta) - ((0.5 * g) * (t ** 2))
        z = 0.1*x

        """ appends selected data for ability to plot"""
        x3table += [x/3]
        y3table += [y/3]
        z3table += [z/3]
    x3table, y3table = zip(*sorted(zip(x3table, y3table)))
    ###############################################################################################################################
    ## figure 1 - 3 dimensional position
    # PLOT 1
    ###############################################################################################################################
    #fig1 = plt.figure(1) 

    ax1.plot3D(x3table ,z3table ,y3table)
    plt.title('Fig1 Test')
    plt.savefig('Subtest_fig1.png', bbox_inches='tight')
    #if k==nn
            #plt.savefig('Subtest_fig1.png', bbox_inches='tight')
    plt.clf()
    ###############################################################################################################################
    ## figure 2 - 2 dimensional position
    # PLOT 2
    ###############################################################################################################################
    #fig2 = plt.figure(2)
    ax2.plot(x3table ,y3table)#,'linewidth',2)
    plt.title('Fig2 Test')
    plt.savefig('Subtest_fig2.png', bbox_inches='tight')
    #if k==nn
            #plt.savefig('Subtest_fig1.png', bbox_inches='tight')
    plt.clf()

plt.show()

The figures should appear as shown in Subtest_fig1.png and Subtest_fig2.png (6 plots on each figure).

Upvotes: 2

Views: 2957

Answers (1)

Guinther Kovalski

Reputation: 1909

Probably the problem is that you are saving every figure under exactly the same name, so each save overwrites the previous one. Try this:

plt.savefig(str(kk)+'Subtest_fig2.png', bbox_inches='tight')
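If the goal is instead one file per figure with all six curves on it, one option is to keep a handle to each Figure and call its own savefig method after the loop. In the question's code, plt.title, plt.savefig and plt.clf all act on the current figure, which is the most recently created one (fig3 there), so they never touch the figures that ax1 and ax2 belong to. That is likely also why Figure 2 came out blank without fig3: plt.clf was then clearing Figure 2 itself. Here is a minimal sketch, assuming the trajectories can be computed as NumPy arrays and that gravity is written as a positive magnitude. The variable names, the flight-time formula and the Agg backend line are my assumptions, not part of the original script:

```python
import math

import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

angles_deg = [20, 30, 40, 45, 50, 60]  # launch angles from the question
v = 30.0   # launch speed
g = 9.8    # gravity as a positive magnitude (assumption)

# Create each figure once, before the loop, and keep the handles.
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection="3d")
fig2, ax2 = plt.subplots()

for angle in angles_deg:
    theta = math.radians(angle)
    t_end = 2.0 * v * math.sin(theta) / g   # time until the projectile lands
    t = np.linspace(0.0, t_end, 200)

    x = v * t * np.cos(theta)
    y = v * t * np.sin(theta) - 0.5 * g * t**2
    z = 0.1 * x

    # Draw on the Axes objects directly; no plt.clf() inside the loop.
    ax1.plot(x, z, y, label=f"{angle} deg")
    ax2.plot(x, y, label=f"{angle} deg")

ax1.set_title("Fig1 Test")
ax2.set_title("Fig2 Test")
ax2.legend()

# Save each figure through its own handle, after the loop.
fig1.savefig("Subtest_fig1.png", bbox_inches="tight")
fig2.savefig("Subtest_fig2.png", bbox_inches="tight")
```

Because fig1.savefig and fig2.savefig name the figure explicitly, it no longer matters which figure happens to be current. Remove the matplotlib.use("Agg") line if you also want plt.show() to open windows.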

Also, here is an example of how subplots work.

import numpy as np
import matplotlib.pyplot as plt

def rotate(px,py,teta):
    newX = px*np.cos(teta)+py*np.sin(teta)
    newY = px*np.sin(teta)-py*np.cos(teta)
    return newX,newY

fig, axs = plt.subplots(3, 3,figsize = (15,15))

x = np.cos(np.linspace(-3.24,3.24,100))
y = 3*np.sin(np.linspace(-3.24,3.24,100))

teta = 0
for i in range(3):
    for j in range(3):
        teta += (360/10)*np.pi/180
        # rotate the original curve (x, y); px, py are not defined before this line
        px,py = rotate(x,y,teta)
        
        axs[j, i].scatter(px,py)
        axs[j, i].axis(xmin=-3,xmax=3)
        axs[j, i].axis(ymin=-3,ymax=3)
    
plt.savefig("test.svg",dpi=2000)


If you want to produce a separate figure on each iteration, you don't need subplots. Just create a new figure and plot on it, like:

import numpy as np
import matplotlib.pyplot as plt

def rotate(px,py,teta):
    newX = px*np.cos(teta)+py*np.sin(teta)
    newY = px*np.sin(teta)-py*np.cos(teta)
    return newX,newY

#fig, axs = plt.subplots(3, 3,figsize = (15,15))

x = np.cos(np.linspace(-3.24,3.24,100))
y = 3*np.sin(np.linspace(-3.24,3.24,100))

teta = 0
for i in range(3):
    for j in range(3):
        teta += (360/10)*np.pi/180
        # rotate the original curve (x, y); px, py are not defined before this line
        px,py = rotate(x,y,teta)
        
        plt.figure(figsize=(5,5))
        plt.cla()
        plt.xlim(-3,3)
        plt.ylim(-3,3)
        plt.scatter(px,py)

        plt.savefig(str(i)+str(j)+"test.png",dpi=300)
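A variation on the same idea, in case many figures are produced this way: hold on to the Figure object, save through it, and close it once the file is written so that figures do not pile up in memory. This is only a sketch. The file-name pattern, the single loop, the inlined rotation and the Agg backend line are my assumptions:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

angle = np.linspace(-np.pi, np.pi, 100)
x = np.cos(angle)
y = 3 * np.sin(angle)

saved = []  # file names written so far
for k in range(3):
    teta = (k + 1) * np.pi / 5
    # same rotation formula as the rotate() helper above
    px = x * np.cos(teta) + y * np.sin(teta)
    py = x * np.sin(teta) - y * np.cos(teta)

    fig, ax = plt.subplots(figsize=(5, 5))  # a fresh figure for this iteration
    ax.scatter(px, py)
    ax.set_xlim(-3, 3)
    ax.set_ylim(-3, 3)

    filename = f"{k}_rotated.png"  # unique name per iteration, so nothing is overwritten
    fig.savefig(filename, dpi=100)
    plt.close(fig)                 # release the figure once it is on disk
    saved.append(filename)
```

Each iteration gets its own file, and plt.close(fig) keeps the number of open figures at zero between iterations.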


...

Upvotes: 2
