carl

pyplot: really slow creating heatmaps

I have a loop that executes its body about 200 times. In each iteration, it does a sophisticated calculation, and then, for debugging, I want to produce a heatmap of an NxM matrix. But generating this heatmap is unbearably slow and significantly slows down an already slow algorithm.

My code is along these lines:

import numpy
import matplotlib.pyplot as plt
for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))

The matrix, from numpy, is not huge: 300 x 600 doubles. If I update an on-screen plot instead of saving the figure, it is even slower.

Surely I must be abusing pyplot. (Matlab can do this, no problem.) How do I speed this up?

Answers (2)

Steve Tjoa

I think this is a bit faster:

import matplotlib.pyplot as plt
from matplotlib import cm
fig = plt.figure()
ax = fig.add_axes([0.1,0.1,0.8,0.8])
for i in range(200):
    matrix = complex_calculation()
    ax.imshow(matrix, cmap=cm.gray)
    fig.savefig("frame{0}.png".format(i))

plt.imshow calls gca, which calls gcf, which checks whether a figure exists and creates one if it does not. By creating the figure and axes yourself before the loop, you skip that lookup on every iteration.
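Taking the reuse idea one step further, a common alternative (not part of this answer) is to create the image artist once and only swap its pixel data with AxesImage.set_data on each iteration, so the axes never accumulate extra images. This is a minimal sketch under stated assumptions: the off-screen Agg backend, a fixed matrix shape, values in [0, 1], and a hypothetical complex_calculation stand-in that just returns random data.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen rendering only, no interactive window
import matplotlib.pyplot as plt
import numpy as np


def complex_calculation(rng, shape):
    """Hypothetical stand-in for the asker's computation: random values in [0, 1)."""
    return rng.random(shape)


def render_frames(n_frames=200, prefix="frame", shape=(300, 600)):
    """Save n_frames heatmaps while reusing one figure, one axes and one image."""
    rng = np.random.default_rng(0)
    fig, ax = plt.subplots()
    # Create the image artist once; vmin/vmax fix the gray scale to the assumed [0, 1] range.
    image = ax.imshow(np.zeros(shape), cmap="gray", vmin=0.0, vmax=1.0)
    for i in range(n_frames):
        matrix = complex_calculation(rng, shape)
        image.set_data(matrix)  # replace the pixels instead of adding a new image
        # If the value range changes between frames, rescale the colormap instead:
        # image.set_clim(matrix.min(), matrix.max())
        fig.savefig("{0}{1}.png".format(prefix, i))
    n_images = len(ax.images)  # stays at 1 no matter how many frames were saved
    plt.close(fig)
    return n_images
```

Keeping a single artist means each savefig renders exactly one image, whereas calling imshow repeatedly on the same axes stacks a new image on every iteration.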

unutbu

Try putting plt.clf() in the loop to clear the current figure:

for i in range(200):
    matrix = complex_calculation()
    plt.set_cmap("gray")
    plt.imshow(matrix)
    plt.savefig("frame{0}.png".format(i))
    plt.clf()

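If you don't do this, every plt.imshow call adds another image to the same axes, so the figure keeps growing: memory use climbs, and each savefig has to render all the images accumulated so far, which makes every iteration slower than the last.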
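The accumulation is easy to check. The sketch below assumes the off-screen Agg backend and uses a small hypothetical 30 x 60 array of zeros in place of the real matrix; it records how many images the current axes hold after each plt.imshow call, with and without plt.clf():

```python
import matplotlib
matplotlib.use("Agg")  # assumption: off-screen backend, nothing is shown on screen
import matplotlib.pyplot as plt
import numpy as np


def images_on_axes(n_frames, clear_each_time):
    """Return the number of images on the current axes after each imshow call."""
    plt.figure()
    counts = []
    for _ in range(n_frames):
        plt.imshow(np.zeros((30, 60)), cmap="gray")  # hypothetical placeholder data
        counts.append(len(plt.gca().images))
        if clear_each_time:
            plt.clf()  # drop the axes and their images; the next imshow starts fresh
    plt.close("all")
    return counts


print(images_on_axes(3, clear_each_time=False))  # [1, 2, 3]
print(images_on_axes(3, clear_each_time=True))   # [1, 1, 1]
```

Without plt.clf() the count grows by one per iteration, which is the growth that slows the original loop down.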
