feathy123

Reputation: 13

Subplots via iteration through a dictionary of arrays in Python with Matplotlib

I have a dictionary of arrays (40 entries, each array of shape (5001,)) and I want to iterate through it and create subplots in a 5 × 8 grid. So far I can only work out how to make a 40 × 1 grid:

import matplotlib.pyplot as plt
# new_dict maps 40 keys to arrays of shape (5001,)
fig, axes = plt.subplots(40,sharex=True,sharey=True,figsize=(3,30))
for i, (key, value) in enumerate(new_dict.items()):
    print(i, key, value)
    axes[i].plot(value)
    axes[i].set(title=key.upper(), xlabel='ns')
plt.show()

[Figure: the resulting 40 × 1 grid]
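For reference, the 40 × 1 version works because `plt.subplots` squeezes a single row or column of subplots into a 1-D array, so `axes[i]` is one `Axes`. With both `nrows` and `ncols` greater than 1 it returns a 2-D array, in which `axes[i]` is an entire row. A quick shape check (the Agg backend is chosen here only so the sketch runs without a display):

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so this sketch runs without a window
import matplotlib.pyplot as plt

fig1, axes1 = plt.subplots(40, sharex=True, sharey=True)               # 40 rows, 1 column
fig2, axes2 = plt.subplots(nrows=8, ncols=5, sharex=True, sharey=True)  # 8 rows, 5 columns

print(axes1.shape)          # (40,)  -> axes1[i] is a single Axes
print(axes2.shape)          # (8, 5) -> axes2[i] is a row of 5 Axes
print(axes2.ravel().shape)  # (40,)  -> flattened, one Axes per index

plt.close("all")
```

This is why the answer below flattens the grid before indexing it with `i`.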

Something like this puts the last array into every subplot of the 5 × 8 grid:

fig, axes = plt.subplots(ncols=5,nrows=8,sharex=True,sharey=True,figsize=(10,30))
axes = axes.flatten()
for i, ax in enumerate(axes):
    for a, (key, value) in enumerate(new_dict.items()):
        print(a, key, value)
    # runs only after the inner loop has finished,
    # so key and value always refer to the last dictionary entry
    ax.plot(value)
    ax.set(title=key.upper(), xlabel='ns')
plt.show()

[Figure: the final plot repeated in every subplot]
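Every subplot shows the same curve because `ax.plot(value)` runs after the inner loop has finished, and a Python `for` loop leaves its loop variables bound to the last item it visited. A matplotlib-free sketch with a hypothetical three-entry dictionary standing in for `new_dict`:

```python
# Hypothetical stand-in for new_dict with three entries
new_dict = {"a": [1, 2], "b": [3, 4], "c": [5, 6]}

plotted = []
for ax_index in range(3):                 # stands in for the loop over axes
    for key, value in new_dict.items():   # inner loop runs to completion...
        pass
    plotted.append((key, value))          # ...so key/value are always the last entry

print(plotted)  # [('c', [5, 6]), ('c', [5, 6]), ('c', [5, 6])]
```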

And switching the for loops puts all the plots, overlaid, in each graph of the 5 × 8 grid:

fig, axes = plt.subplots(ncols=5,nrows=8,sharex=True,sharey=True,figsize=(10,30))
for i, ax in enumerate(axes.flatten()):
    for a, (key, value) in enumerate(new_dict.items()):
        print(a, key, value)
        # inside the inner loop: every subplot receives all 40 arrays
        ax.plot(value)
        ax.set(title=key.upper(), xlabel='ns')
plt.show()

[Figure: all plots overlaid in every subplot]
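When `ax.plot(value)` is indented into the inner loop, each of the 40 subplots receives all 40 arrays, 1,600 lines in total. Counting `ax.lines` makes this visible. The sketch assumes random stand-in data with the shape described in the question (40 arrays of length 5001, hypothetical key names) and uses the headless Agg backend:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; nothing is displayed
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-in: 40 arrays of length 5001
new_dict = {f"key{i}": rng.standard_normal(5001) for i in range(40)}

fig, axes = plt.subplots(ncols=5, nrows=8, sharex=True, sharey=True)
for ax in axes.flatten():
    for key, value in new_dict.items():
        ax.plot(value)  # inside the inner loop: every Axes gets every array

lines_per_axes = [len(ax.lines) for ax in axes.flatten()]
print(set(lines_per_axes), sum(lines_per_axes))  # {40} 1600
plt.close(fig)
```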

I cannot for the life of me work out how to put a different plot in each graph. Any help would be greatly appreciated! I feel like I'm missing something really obvious here. Many thanks :-)

EDIT: I've realised that the position of the for loops doesn't matter; it is the indentation of the ax.plot(...) line that decides whether each graph shows only the last plot or all of the plots.


Answers (1)

Quang Hoang

Reputation: 150735

Create the subplot grid with the size you want, flatten the array of axes, and then index it with the counter from enumerate:

import numpy as np
import matplotlib.pyplot as plt

# sample data
new_dict = {str(i): np.linspace(0,1,100)**i for i in range(40)}

# 5 rows x 8 columns = 40 subplots
fig, axes = plt.subplots(5,8,sharex=True,sharey=True,figsize=(8,5))

# flatten the 2-D axes array so you can access axes[i]
axes = axes.ravel()

for i, (key, value) in enumerate(new_dict.items()):
    print(i, key, value)
    axes[i].plot(value)
    axes[i].set(title=key.upper(), xlabel='ns')
plt.show()

You would get something like this:

[Figure: the resulting 5 × 8 grid with one curve per subplot]
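A variant of the answer's approach, sketched here rather than taken from it, pairs each subplot with one dictionary entry via `zip`. Because `zip` stops at the shorter sequence, it cannot index past the end if the dictionary has fewer entries than the grid has cells; leftover cells are then hidden with `set_visible(False)`. The sample data mirrors the answer's but is deliberately cut to 38 entries to show the hiding, and the 8-row × 5-column layout follows the question:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch; drop it for an interactive window
import matplotlib.pyplot as plt
import numpy as np

# Sample data as in the answer, but only 38 entries so two cells stay empty
new_dict = {str(i): np.linspace(0, 1, 100) ** i for i in range(38)}

fig, axes = plt.subplots(nrows=8, ncols=5, sharex=True, sharey=True, figsize=(10, 16))
flat_axes = axes.ravel()

# One dictionary entry per Axes; zip stops at the shorter of the two
for ax, (key, value) in zip(flat_axes, new_dict.items()):
    ax.plot(value)
    ax.set(title=key.upper(), xlabel="ns")

# Hide grid cells that received no data
for ax in flat_axes[len(new_dict):]:
    ax.set_visible(False)
```

In an interactive session, remove the `matplotlib.use("Agg")` line and finish with `plt.show()`. Pairing the axes and items directly also removes the need for the separate counter `i`.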
