Maelstorm
Maelstorm

Reputation: 640

Speed up my interactive matplotlib figure

I am loading a jpeg figure (large: 1300x2000) into matplotlib, drawing a grid of 50x50 squares over it and clicking on each square to color code it. However, I notice that the program lags far behind my clicks and takes up to 30 seconds to catch up if i quick about 50 squares at a reasonable speed. I am wondering if someone might be able to speed up. Below is my script which is ready to go if you copy/paste it (and have scipy,numpy, matplotlib, pillow, and tkinter)

Any advice will be welcome. I am a medical scientist so please forgive me if the code is not well explained:

import matplotlib
import matplotlib.pyplot as plt
import tkinter
import tkinter.filedialog
from matplotlib.figure import  Figure
import math, sys
import numpy as np
import scipy.io as sio
from PIL import Image
from numpy import arange, sin, pi
#from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
#matplotlib.matplotlib_fname()
import os, re

global stridesize, classnumber,x, im,fn, plt, fig, mask

classnumber = 1



def onmove(eve):
    global x,im, plt
    print(eve.ydata)
    print(eve.button)
    if (eve.ydata !=None) and (eve.xdata !=None):
        if eve.button==1:
            print(eve.button)
            xcoord = int(eve.xdata)
            ycoord = int(eve.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
           # print(eve.xdata, int(eve.ydata), stridesize)
            if(classnumber==1):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,0])
            if(classnumber==2):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,0])
            if(classnumber==3):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,255])
            if(classnumber==4):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,255])
            if(classnumber==5):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,255,0])
            if(classnumber==6):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,255])
            if(classnumber==7):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([100,255,50])


        if eve.button==3:
            xcoord = int(eve.xdata)
            ycoord = int(eve.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
            print(eve.xdata, int(eve.ydata), stridesize)
            mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,0])
        im.set_data(mask)
        fig.canvas.draw()




def onclick(event):

    if (event.ydata !=None) and (event.xdata !=None):
        global x, im, fig
        if event.button==1:
            xcoord = int(event.xdata)
            ycoord = int(event.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
            print(event.xdata, int(event.ydata), stridesize)
            if(classnumber==1):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,0])
            if(classnumber==2):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,0])
            if(classnumber==3):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,255])
            if(classnumber==4):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,255])
            if(classnumber==5):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,255,0])
            if(classnumber==6):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,255])
            if(classnumber==7):
                mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([100,255,50])
            im.set_data(mask)

        if event.button==3:
            xcoord = int(event.xdata)
            ycoord = int(event.ydata)
            startX = math.floor(xcoord/stridesize)*stridesize
            startY = math.floor(ycoord/stridesize)*stridesize
            print(event.xdata, int(event.ydata), stridesize)
            mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,0])
            im.set_data(mask)
        fig.canvas.draw()


def onpress(event):
    global classnumber, mask
    if (event.key == 'e'):
       print("YO")
       mask[:,:,:]=0;
       im.set_data(mask)
       fig.canvas.draw()

    if (event.key=='s'):
        savemask(fn)
    if (event.key=='r'):
        plt.figure();
        plt.imshow(mask);
        plt.show();
    if int(event.key) > 0 and int(event.key) <9 :
       classnumber = int(event.key)
       print(classnumber)


def onrelease(event):
    print(event.button)
 #   im.set_data(mask)




def savemask(fn):
    # matrixname =os.path.basename(filename)
    # matrixname = re.sub(r'\.jpg','',matrixname)
    pre, ext = os.path.splitext(fn)
    savename_default = os.path.basename(pre)
    options = {}
    options['defaultextension'] = ''
    options['filetypes'] = [('mat files', '.mat')]
    options['initialdir'] = ''
    options['initialfile'] = savename_default
    options['title'] = 'Save file'


    f = tkinter.filedialog.asksaveasfile(**options)
    if f is None: # asksaveasfile return `None` if aadialog closed with "cancel".
        return
    name = f.name
    sio.savemat(name,{'mask':mask},do_compression=True)
    f.close()




root = tkinter.Tk()
root.withdraw()

options = {}

options['defaultextension'] = '.jpg'

options['filetypes'] = [('Jpeg', '.jpg')]

options['initialdir'] = 'C:\\'
options['initialfile']= ''
options['parent'] = root

options['title'] = 'This is a title'


fn= tkinter.filedialog.askopenfilename(**options)


img = Image.open(fn)
x = np.asarray(img)
x.setflags(write=1)
#masksize= (x.shape[0],x.shape[1],4)
mask= np.zeros(x.shape,'uint8')
#mask[:,:,3]=0.2
fig = plt.figure()
fig.suptitle(r'Key codes: 1 = Tumour, 2 = stroma-hypocellular, 3=stroma cellular (inflammatory)' '\n4 = proteinaceous, 5= red cells, 6,7: anyother,''\nRight click: clear square''\n r:  review mask, e: erase mask, o : open mask image, s : save mask image;')


im=plt.imshow(x)
im=plt.imshow(mask,alpha=.25)
ax = plt.gca();

stridesize = 50;

plt.rcParams['keymap.save']=''
ax.set_yticks(np.arange(0, x.shape[0], stridesize));
ax.set_xticks(np.arange(0, x.shape[1], stridesize));

cid = fig.canvas.mpl_connect('button_press_event', onclick)
cod = fig.canvas.mpl_connect('key_press_event', onpress)
#cdd = fig.canvas.mpl_connect('motion_notify_event', onmove)
cdr = fig.canvas.mpl_connect('button_release_event', onrelease)

plt.grid(b=True, which='both', color='black',linestyle='-')
#
plt.show()

plt.ion()

Upvotes: 3

Views: 565

Answers (1)

Imanol Luengo
Imanol Luengo

Reputation: 15909

First, I would recommend avoiding the use of global variables at all cost. You can replace it by using a class instead. Find bellow a fully working summarized version of what your code intends to do:

import numpy as np

import matplotlib
matplotlib.use('Qt4Agg')

from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap

class ColorCode(object):

    def __init__(self, block_size=(50,50), colors=['red', 'green', 'blue'], alpha=0.3):
        self.by, self.bx = block_size # block size
        self.selected = 0 # selected color
        self.colors = colors
        self.cmap = ListedColormap(colors) # color map for labels
        self.mask = None # annotation mask
        self.alpha = alpha
        # Plots
        self.fig = plt.figure()
        self.ax = self.fig.gca()
        # Events
        self.fig.canvas.mpl_connect('button_press_event', self.on_click)
        self.fig.canvas.mpl_connect('key_press_event', self.on_key)

    def color_code(self, img):
        self.imshape = img.shape[:2]
        self.mask = np.full(img.shape[:2], -1, np.int32) # masked labels
        self.ax.imshow(img) # show image
        self.ax.imshow(np.ma.masked_where(self.mask < 0, self.mask), cmap=self.cmap,
                       alpha=self.alpha, vmin=0, vmax=len(self.colors)) # show mask
        # Run
        plt.show(block=True)
        return self.mask

    def on_click(self, event):
        if not event.inaxes or self.mask is None:
            return
        # Get corresponding coordinates
        py, px = int(event.ydata), int(event.xdata)
        cy, cx = py//self.by, px//self.bx # grid coordinates
        ymin = cy * self.by
        ymax = min((cy+1) * self.by, self.imshape[0])
        xmin = cx * self.bx
        xmax = min((cx+1) * self.bx, self.imshape[1])
        # Update mask
        if event.button == 1:
            self.mask[ymin:ymax, xmin:xmax] = self.selected
        elif event.button == 3:
            self.mask[ymin:ymax, xmin:xmax] = -1
        # Update figure
        self.ax.images[1].set_data(np.ma.masked_where(self.mask < 0, self.mask))
        self.fig.canvas.draw_idle()

    def on_key(self, event):
        ikey = int(event.key)
        if 0 <= ikey < len(self.colors):
            self.selected = ikey

The key differences with your code are:

  1. It doesn't use global variables, instead, it uses class-variables. Making it safer to run and easier to extend/modify.

  2. Instead of colouring a 3-dimensional mask for your annotations, annotations are saved as a 2-dimensional mask, where every pixel has a value in the range [1, len(colors)] indicating to which colour it belongs. Colours are then added to the plot by using ListedColormap to set a custom colormap for your plot.

  3. It draws the image and overlays on top of it a segmentation mask. Initially the mask is filled with -1, meaning it has no label. By using a numpy's masked array you can mask where mask < 0 to not show that in the plot, making the plot transparent where mask < 0 and coloured otherwise.

  4. List of possible colours are provided as a parameter to the class. It will allow you to select a colour from 0 to len(colors) with a maximum of 10 colors (as it is currently binded to numbers in the keyboard).

  5. fig.canvas.draw_idle is much better than fig.canvas.draw. The latter blocks the program until it finishes drawing.

  6. As everything is into a class, code looks much cleaner.

You can call the code as:

>>> random_image = np.random.randn(1000,2000, 3)
>>> result = ColorCode().color_code(random_image)

and result will contain the labelling mask where each pixel has a number indicating with witch color it has been tagged (-1 if none). Last, other parameters can be passed to ColorCode's constructor, such as block_size=(100,100) for different block size, alpha=0.5 for less opacity in the mask (or None as alpha=1).

Hope it works for you, or that at least you can grab some ideas from it.

Upvotes: 3

Related Questions