Reputation: 640
I am loading a jpeg figure (large: 1300x2000) into matplotlib, drawing a grid of 50x50 squares over it and clicking on each square to color code it. However, I notice that the program lags far behind my clicks and takes up to 30 seconds to catch up if i quick about 50 squares at a reasonable speed. I am wondering if someone might be able to speed up. Below is my script which is ready to go if you copy/paste it (and have scipy,numpy, matplotlib, pillow, and tkinter)
Any advice will be welcome. I am a medical scientist so please forgive me if the code is not well explained:
import matplotlib
import matplotlib.pyplot as plt
import tkinter
import tkinter.filedialog
from matplotlib.figure import Figure
import math, sys
import numpy as np
import scipy.io as sio
from PIL import Image
from numpy import arange, sin, pi
#from matplotlib.backends.backend_qt5agg import FigureCanvasQTAgg as FigureCanvas
#matplotlib.matplotlib_fname()
import os, re
global stridesize, classnumber,x, im,fn, plt, fig, mask
classnumber = 1
def onmove(eve):
global x,im, plt
print(eve.ydata)
print(eve.button)
if (eve.ydata !=None) and (eve.xdata !=None):
if eve.button==1:
print(eve.button)
xcoord = int(eve.xdata)
ycoord = int(eve.ydata)
startX = math.floor(xcoord/stridesize)*stridesize
startY = math.floor(ycoord/stridesize)*stridesize
# print(eve.xdata, int(eve.ydata), stridesize)
if(classnumber==1):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,0])
if(classnumber==2):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,0])
if(classnumber==3):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,255])
if(classnumber==4):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,255])
if(classnumber==5):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,255,0])
if(classnumber==6):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,255])
if(classnumber==7):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([100,255,50])
if eve.button==3:
xcoord = int(eve.xdata)
ycoord = int(eve.ydata)
startX = math.floor(xcoord/stridesize)*stridesize
startY = math.floor(ycoord/stridesize)*stridesize
print(eve.xdata, int(eve.ydata), stridesize)
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,0])
im.set_data(mask)
fig.canvas.draw()
def onclick(event):
if (event.ydata !=None) and (event.xdata !=None):
global x, im, fig
if event.button==1:
xcoord = int(event.xdata)
ycoord = int(event.ydata)
startX = math.floor(xcoord/stridesize)*stridesize
startY = math.floor(ycoord/stridesize)*stridesize
print(event.xdata, int(event.ydata), stridesize)
if(classnumber==1):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,0])
if(classnumber==2):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,0])
if(classnumber==3):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,255])
if(classnumber==4):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,0,255])
if(classnumber==5):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([255,255,0])
if(classnumber==6):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,255,255])
if(classnumber==7):
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([100,255,50])
im.set_data(mask)
if event.button==3:
xcoord = int(event.xdata)
ycoord = int(event.ydata)
startX = math.floor(xcoord/stridesize)*stridesize
startY = math.floor(ycoord/stridesize)*stridesize
print(event.xdata, int(event.ydata), stridesize)
mask[startY:startY+stridesize,startX:startX+stridesize,:]=np.array([0,0,0])
im.set_data(mask)
fig.canvas.draw()
def onpress(event):
global classnumber, mask
if (event.key == 'e'):
print("YO")
mask[:,:,:]=0;
im.set_data(mask)
fig.canvas.draw()
if (event.key=='s'):
savemask(fn)
if (event.key=='r'):
plt.figure();
plt.imshow(mask);
plt.show();
if int(event.key) > 0 and int(event.key) <9 :
classnumber = int(event.key)
print(classnumber)
def onrelease(event):
print(event.button)
# im.set_data(mask)
def savemask(fn):
# matrixname =os.path.basename(filename)
# matrixname = re.sub(r'\.jpg','',matrixname)
pre, ext = os.path.splitext(fn)
savename_default = os.path.basename(pre)
options = {}
options['defaultextension'] = ''
options['filetypes'] = [('mat files', '.mat')]
options['initialdir'] = ''
options['initialfile'] = savename_default
options['title'] = 'Save file'
f = tkinter.filedialog.asksaveasfile(**options)
if f is None: # asksaveasfile return `None` if aadialog closed with "cancel".
return
name = f.name
sio.savemat(name,{'mask':mask},do_compression=True)
f.close()
root = tkinter.Tk()
root.withdraw()
options = {}
options['defaultextension'] = '.jpg'
options['filetypes'] = [('Jpeg', '.jpg')]
options['initialdir'] = 'C:\\'
options['initialfile']= ''
options['parent'] = root
options['title'] = 'This is a title'
fn= tkinter.filedialog.askopenfilename(**options)
img = Image.open(fn)
x = np.asarray(img)
x.setflags(write=1)
#masksize= (x.shape[0],x.shape[1],4)
mask= np.zeros(x.shape,'uint8')
#mask[:,:,3]=0.2
fig = plt.figure()
fig.suptitle(r'Key codes: 1 = Tumour, 2 = stroma-hypocellular, 3=stroma cellular (inflammatory)' '\n4 = proteinaceous, 5= red cells, 6,7: anyother,''\nRight click: clear square''\n r: review mask, e: erase mask, o : open mask image, s : save mask image;')
im=plt.imshow(x)
im=plt.imshow(mask,alpha=.25)
ax = plt.gca();
stridesize = 50;
plt.rcParams['keymap.save']=''
ax.set_yticks(np.arange(0, x.shape[0], stridesize));
ax.set_xticks(np.arange(0, x.shape[1], stridesize));
cid = fig.canvas.mpl_connect('button_press_event', onclick)
cod = fig.canvas.mpl_connect('key_press_event', onpress)
#cdd = fig.canvas.mpl_connect('motion_notify_event', onmove)
cdr = fig.canvas.mpl_connect('button_release_event', onrelease)
plt.grid(b=True, which='both', color='black',linestyle='-')
#
plt.show()
plt.ion()
Upvotes: 3
Views: 565
Reputation: 15909
First, I would recommend avoiding the use of global
variables at all cost. You can replace it by using a class
instead. Find bellow a fully working summarized version of what your code intends to do:
import numpy as np
import matplotlib
matplotlib.use('Qt4Agg')
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap
class ColorCode(object):
def __init__(self, block_size=(50,50), colors=['red', 'green', 'blue'], alpha=0.3):
self.by, self.bx = block_size # block size
self.selected = 0 # selected color
self.colors = colors
self.cmap = ListedColormap(colors) # color map for labels
self.mask = None # annotation mask
self.alpha = alpha
# Plots
self.fig = plt.figure()
self.ax = self.fig.gca()
# Events
self.fig.canvas.mpl_connect('button_press_event', self.on_click)
self.fig.canvas.mpl_connect('key_press_event', self.on_key)
def color_code(self, img):
self.imshape = img.shape[:2]
self.mask = np.full(img.shape[:2], -1, np.int32) # masked labels
self.ax.imshow(img) # show image
self.ax.imshow(np.ma.masked_where(self.mask < 0, self.mask), cmap=self.cmap,
alpha=self.alpha, vmin=0, vmax=len(self.colors)) # show mask
# Run
plt.show(block=True)
return self.mask
def on_click(self, event):
if not event.inaxes or self.mask is None:
return
# Get corresponding coordinates
py, px = int(event.ydata), int(event.xdata)
cy, cx = py//self.by, px//self.bx # grid coordinates
ymin = cy * self.by
ymax = min((cy+1) * self.by, self.imshape[0])
xmin = cx * self.bx
xmax = min((cx+1) * self.bx, self.imshape[1])
# Update mask
if event.button == 1:
self.mask[ymin:ymax, xmin:xmax] = self.selected
elif event.button == 3:
self.mask[ymin:ymax, xmin:xmax] = -1
# Update figure
self.ax.images[1].set_data(np.ma.masked_where(self.mask < 0, self.mask))
self.fig.canvas.draw_idle()
def on_key(self, event):
ikey = int(event.key)
if 0 <= ikey < len(self.colors):
self.selected = ikey
The key differences with your code are:
It doesn't use global variables, instead, it uses class-variables. Making it safer to run and easier to extend/modify.
Instead of colouring a 3-dimensional mask
for your annotations, annotations are saved as a 2-dimensional mask
, where every pixel has a value in the range [1, len(colors)]
indicating to which colour it belongs. Colours are then added to the plot by using ListedColormap
to set a custom colormap for your plot.
It draws the image and overlays on top of it a segmentation mask. Initially the mask is filled with -1
, meaning it has no label. By using a numpy's masked array you can mask where mask < 0
to not show that in the plot, making the plot transparent where mask < 0
and coloured otherwise.
List of possible colours are provided as a parameter to the class. It will allow you to select a colour from 0 to len(colors)
with a maximum of 10 colors (as it is currently binded to numbers in the keyboard).
fig.canvas.draw_idle
is much better than fig.canvas.draw
. The latter blocks the program until it finishes drawing.
As everything is into a class, code looks much cleaner.
You can call the code as:
>>> random_image = np.random.randn(1000,2000, 3)
>>> result = ColorCode().color_code(random_image)
and result
will contain the labelling mask
where each pixel has a number indicating with witch color it has been tagged (-1 if none). Last, other parameters can be passed to ColorCode
's constructor, such as block_size=(100,100)
for different block size, alpha=0.5
for less opacity in the mask (or None as alpha=1
).
Hope it works for you, or that at least you can grab some ideas from it.
Upvotes: 3