Srivatsan

Reputation: 9363

Matplotlib subplots inside a for loop

I have a function that takes three arrays and a constant value as input.

Inside the function, I select 10 different subsets of the arrays using redshift conditions and try to plot each subset in its own subplot (10 subplots in total).

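import matplotlib.pyplot as plt

def ra_vs_dec(alpha,delta,zphot,mlim):
    # data_m200 and data_z are assumed to be module-level arrays defined elsewhere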
    zmin = [0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2]
    zmax = [0.4,0.5,0.6,0.7,0.8,0.9,1.0,1.1,1.2,1.3]
    plot_no = [1,2,3,4,5,6,7,8,9,10] # THESE ARE THE SUBPLOT NUMBERS

    for a,b,c in zip(zmin,zmax,plot_no):
        ra = alpha[(data_m200>mlim)*(data_z>a)*(data_z<b)] # RA FOR ZCOSMO
        dec = delta[(data_m200>mlim)*(data_z>a)*(data_z<b)] # DEC FOR ZCOSMO
        ra_zphot = alpha[(data_m200>mlim)*(zphot>a)*(zphot<b)] # RA FOR ZPHOT
        dec_zphot = delta[(data_m200>mlim)*(zphot>a)*(zphot<b)] # DEC FOR ZPHOT

        fig = plt.figure()
        ax = fig.add_subplot(2,5,c)
        ax.scatter(ra,dec,color='red',s=5.0,label=''+str(a)+'<zcosmo<'+str(b)+'')
        ax.scatter(ra_zphot,dec_zphot,color='blue',s=5.0,label=''+str(a)+'<zphot<'+str(b)+'')
        ax.legend(loc='best',scatterpoints=2)

    fig.show()

However, when I run the code above, I only get the final subplot, i.e. the 10th one. What am I doing wrong?

I would like to see all 10 subplots.
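A side note on the masks used in the loop: NumPy multiplies boolean arrays element-wise, so chaining conditions with * acts as a logical AND, the same as the more explicit & operator. The small sketch below uses made-up sample arrays (hypothetical stand-ins for data_m200 and data_z) to check that the two spellings select the same elements:

```python
import numpy as np

# Hypothetical sample values standing in for data_m200 and data_z from the question
data_m200 = np.array([1.0, 5.0, 8.0, 3.0, 9.0])
data_z = np.array([0.35, 0.45, 0.38, 0.32, 0.55])
mlim, a, b = 4.0, 0.3, 0.4

# Element-wise product of boolean arrays: True only where every condition holds
mask_mul = (data_m200 > mlim) * (data_z > a) * (data_z < b)
# Equivalent, more explicit element-wise AND
mask_and = (data_m200 > mlim) & (data_z > a) & (data_z < b)

print(mask_mul.tolist())                   # [False, False, True, False, False]
print(np.array_equal(mask_mul, mask_and))  # True
```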

Upvotes: 0

Views: 3432

Answers (1)

tmdavison

Reputation: 69116

Move the creation of the figure outside of the loop. With plt.figure() inside the loop, you create 10 separate figures and add only one subplot to each of them. Because the name fig is rebound on every iteration, the fig.show() call after the loop only shows the figure created in the final iteration, which contains just the 10th subplot.
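To see the difference concretely, this minimal sketch counts the open figures, and the axes on the returned figure, for both patterns. It uses empty subplots instead of the question's data; the helper names figures_inside_loop and figure_outside_loop are hypothetical, and the Agg backend is chosen only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt

def figures_inside_loop(n=10):
    """Question's pattern: a brand-new figure on every iteration."""
    for c in range(1, n + 1):
        fig = plt.figure()        # rebinds `fig` to a new, empty figure
        fig.add_subplot(2, 5, c)  # each figure receives exactly one subplot
    return fig                    # only the last figure is still bound to `fig`

def figure_outside_loop(n=10):
    """The fix: one figure, with every subplot added to it."""
    fig = plt.figure()
    for c in range(1, n + 1):
        fig.add_subplot(2, 5, c)
    return fig

plt.close("all")
last = figures_inside_loop()
print(len(plt.get_fignums()), len(last.axes))    # 10 1

plt.close("all")
single = figure_outside_loop()
print(len(plt.get_fignums()), len(single.axes))  # 1 10
```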

fig = plt.figure()  # create ONE figure, before the loop

for a,b,c in zip(zmin,zmax,plot_no):
    ra = alpha[(data_m200>mlim)&(data_z>a)&(data_z<b)] # RA FOR ZCOSMO
    dec = delta[(data_m200>mlim)&(data_z>a)&(data_z<b)] # DEC FOR ZCOSMO
    ra_zphot = alpha[(data_m200>mlim)&(zphot>a)&(zphot<b)] # RA FOR ZPHOT
    dec_zphot = delta[(data_m200>mlim)&(zphot>a)&(zphot<b)] # DEC FOR ZPHOT

    ax = fig.add_subplot(2,5,c)  # add subplot c to that same figure
    ax.scatter(ra,dec,color='red',s=5.0,label=f'{a}<zcosmo<{b}')
    ax.scatter(ra_zphot,dec_zphot,color='blue',s=5.0,label=f'{a}<zphot<{b}')
    ax.legend(loc='best',scatterpoints=2)

plt.show()  # fig.show() may return immediately when run as a plain script
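An alternative sketch, assuming a recent Matplotlib and NumPy, creates the whole 2x5 grid up front with plt.subplots and pairs each axes object with a redshift bin via axes.flat, so there is no subplot index to track. The catalogue arrays below are randomly generated placeholders, and data_z and data_m200 are passed in as extra arguments rather than read from globals as in the question; substitute your real arrays.

```python
import matplotlib
matplotlib.use("Agg")  # remove this line to get an interactive window
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical synthetic catalogue standing in for the real arrays
rng = np.random.default_rng(0)
n = 2000
alpha = rng.uniform(0, 10, n)
delta = rng.uniform(-5, 5, n)
data_z = rng.uniform(0.3, 1.3, n)
zphot = data_z + rng.normal(0, 0.05, n)
data_m200 = rng.uniform(0, 10, n)

def ra_vs_dec(alpha, delta, zphot, data_z, data_m200, mlim):
    zmin = [0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2]
    zmax = [0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3]
    # Create the figure and all 10 axes once, before the loop
    fig, axes = plt.subplots(2, 5, figsize=(20, 8), sharex=True, sharey=True)
    mass_cut = data_m200 > mlim
    for ax, a, b in zip(axes.flat, zmin, zmax):  # axes.flat walks the grid row by row
        cosmo = mass_cut & (data_z > a) & (data_z < b)
        phot = mass_cut & (zphot > a) & (zphot < b)
        ax.scatter(alpha[cosmo], delta[cosmo], color="red", s=5.0, label=f"{a}<zcosmo<{b}")
        ax.scatter(alpha[phot], delta[phot], color="blue", s=5.0, label=f"{a}<zphot<{b}")
        ax.legend(loc="best", scatterpoints=2)
    return fig

fig = ra_vs_dec(alpha, delta, zphot, data_z, data_m200, mlim=5.0)
print(len(fig.axes))  # 10
```

In an interactive session, drop the matplotlib.use("Agg") line and call plt.show() after ra_vs_dec(...) to display the figure.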

Upvotes: 2
