Tilman

Python Using pyplot slider with subplots

I am quite new to Python, so please excuse me if this is a stupid beginner's error; however, I have been struggling with it for quite some time. I want to create a figure with n x m subplots, each subplot showing an np.array of shape [1024,264,264]. As I am looking for differences occurring in the stack along the 0-dimension, I want to use a slider to explore all stacks in my figure simultaneously. The slider works nicely for a figure with one subplot, but I can't get it to drive all of them. This is the code I am using:

import os
from matplotlib import pyplot as plt
import numpy as np

import glob
import h5py
#Define the xy size of the mapped array
xsize=3
ysize=3

lengthh5=9
readlist=[]
for i in range (0,lengthh5):
    npraw=np.random.rand(200,50,50)
    readlist.append (npraw)

''' Slider visualization'''
from matplotlib.widgets import Slider
fig=plt.figure()
for k in range (0,lengthh5):
    ax=fig.add_subplot(xsize,ysize,k)        
    frame = 10
    l = ax.imshow(readlist[k][frame,:,:]) 
    plt.axis('off')
    sframe = Slider(fig.add_subplot(50,1,50), 'Frame', 0, len(readlist[0])-1, valinit=0)
    def update(val):
        frame = np.around(sframe.val)
        l.set_data(readlist[k][frame,:,:])


sframe.on_changed(update)

plt.show()

For this particular case I stripped it down to a 3x3 grid of subplots and just create random (smaller) arrays. Interestingly, the slider only operates on the second-to-last subplot, and I have no real idea how to link it to all subplots simultaneously. Perhaps someone has an idea how to do this. Thanks a lot in advance,

Tilman
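
The symptom above (only one subplot follows the slider) fits how Python resolves names inside a function: update reads l and k when it is called, not when it is defined, so once the loop has finished they hold the values from the last iteration. A pure-Python sketch of the effect, with no Matplotlib involved; the default-argument variant is shown only as one common way to freeze a value per function:

```python
# Each callback returns the loop variable k as seen when it is *called*
late = []
for k in range(3):
    def report():
        return k          # looked up at call time, after the loop has ended
    late.append(report)

# A default argument is evaluated once, when the def statement runs,
# so each function keeps its own copy of k
early = []
for k in range(3):
    def report(k=k):
        return k
    early.append(report)

late_values = [f() for f in late]
early_values = [f() for f in early]
print(late_values)   # [2, 2, 2]
print(early_values)  # [0, 1, 2]
```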


Answers (1)

Ed Smith

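In your version on_changed is only called once, after the loop, with the last update function. On top of that, update looks up l and k when it is called, not when it is defined, so it always sees the values from the final iteration: the image with k = 8, which add_subplot(3, 3, 8) places in the second-to-last cell. That is why only that subplot responds.

You need to store each imshow AxesImage in a list and, inside update, loop over all of them and update each one from the slider value. With current Matplotlib and NumPy two smaller fixes are also needed: subplot indices start at 1 (hence k + 1), and array indices must be integers, whereas np.around returns a float: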

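from matplotlib import pyplot as plt
from matplotlib.widgets import Slider
import numpy as np

# Define the xy size of the mapped array (3 x 3 subplots)
xsize = 3
ysize = 3

# Nine random stacks standing in for the real data
lengthh5 = 9
readlist = []
for i in range(0, lengthh5):
    npraw = np.random.rand(200, 50, 50)
    readlist.append(npraw)

fig = plt.figure()
frame = 10
ls = []  # one AxesImage per subplot, so update() can reach all of them
for k in range(0, lengthh5):
    ax = fig.add_subplot(xsize, ysize, k + 1)  # subplot indices start at 1
    l = ax.imshow(readlist[k][frame, :, :])
    ls.append(l)
    ax.axis('off')

# Give the slider its own axes below the image grid
fig.subplots_adjust(bottom=0.1)
slider_ax = fig.add_axes([0.2, 0.03, 0.6, 0.03])
sframe = Slider(slider_ax, 'Frame',
                0, len(readlist[0]) - 1, valinit=frame)

def update(val):
    # Slider values are floats; array indices must be ints
    frame = int(round(sframe.val))
    for k, l in enumerate(ls):
        l.set_data(readlist[k][frame, :, :])
    fig.canvas.draw_idle()

sframe.on_changed(update)
plt.show()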
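
To check the list-of-images approach without opening a window, the same pattern can be wrapped in a function and driven with Slider.set_val, which runs the on_changed callbacks in the same way as dragging the slider. A sketch under two assumptions: the non-interactive Agg backend, and build_viewer, a hypothetical helper name that is not part of Matplotlib. It also passes valstep=1 so that dragging snaps to whole frames; the int(round(val)) conversion is kept as a safeguard in case a non-integer value reaches the callback.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no window is opened
from matplotlib import pyplot as plt
from matplotlib.widgets import Slider
import numpy as np


def build_viewer(stacks, nrows, ncols, frame=0):
    """Hypothetical helper: one imshow per stack, all driven by one Slider."""
    fig = plt.figure()
    images = []
    for k, stack in enumerate(stacks):
        ax = fig.add_subplot(nrows, ncols, k + 1)  # 1-based subplot index
        images.append(ax.imshow(stack[frame]))
        ax.axis("off")

    # Separate axes for the slider, below the image grid
    fig.subplots_adjust(bottom=0.1)
    slider_ax = fig.add_axes([0.2, 0.03, 0.6, 0.03])
    slider = Slider(slider_ax, "Frame", 0, len(stacks[0]) - 1,
                    valinit=frame, valstep=1)

    def update(val):
        idx = int(round(val))  # guard against float slider values
        for stack, image in zip(stacks, images):
            image.set_data(stack[idx])
        fig.canvas.draw_idle()

    slider.on_changed(update)
    return fig, slider, images


rng = np.random.default_rng(0)
stacks = [rng.random((20, 8, 8)) for _ in range(9)]
fig, slider, images = build_viewer(stacks, 3, 3, frame=10)

slider.set_val(5)  # programmatic equivalent of dragging to frame 5
```

After set_val(5), every AxesImage holds frame 5 of its own stack. For interactive use, drop the matplotlib.use line and call plt.show() instead of set_val.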
