shenglih

Save subplots after .imshow()

Here is my visualization code:

import matplotlib.pyplot as plt

f, ax = plt.subplots(1, 2)
for i, img in enumerate([img1, img2]):
    grads = ...  # placeholder for my visualization code
    # visualize grads as heatmap
    ax[i].imshow(grads, cmap='jet')

How can I save what imshow displays here? Any advice is greatly appreciated!

Upvotes: 3

Views: 8758

Answers (1)

mostlyoxygen

Saving the whole figure is simple; just use the savefig function:

f.savefig('filename.png')

Several file formats are supported (for example PNG, PDF and SVG). The format is inferred from the extension of the filename unless you pass it explicitly with the format argument. See the savefig documentation for more information.
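A small sketch of format inference, assuming the non-interactive Agg backend and random arrays standing in for the question's img1/img2 gradients (the real data is not shown). The same figure is saved three times, and only the extension changes:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so this runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data: random arrays replace the (unshown) gradient maps.
rng = np.random.default_rng(0)
f, ax = plt.subplots(1, 2)
for i in range(2):
    ax[i].imshow(rng.random((32, 32)), cmap="jet")

# savefig picks the output format from each filename's extension.
for name in ("whole_figure.png", "whole_figure.pdf", "whole_figure.svg"):
    f.savefig(name)
```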

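The savefig function takes an argument bbox_inches, which defines the area of the figure to be saved, measured in inches. To save an individual subplot to file, take the bounding box of the subplot's Axes object. That box is reported in display (pixel) coordinates, so it has to be converted to inches, which is what the inverse of the figure's dpi_scale_trans transform does.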
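The coordinate conversion can be checked in isolation. This is a sketch assuming the Agg backend (whose canvas provides get_renderer()), an explicit dpi of 100, and random stand-in data. It converts the left subplot's pixel bounding box to inches and saves only that region; the resulting PNG should be roughly as many pixels wide as the pixel bounding box:

```python
import matplotlib
matplotlib.use("Agg")  # Agg canvases provide get_renderer()
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data in place of the real gradient maps.
rng = np.random.default_rng(0)
f, ax = plt.subplots(1, 2, figsize=(8, 4), dpi=100)
for a in ax:
    a.imshow(rng.random((32, 32)), cmap="jet")

renderer = f.canvas.get_renderer()
# Tight bounding box of the left Axes (ticks and labels included), in pixels.
bbox_px = ax[0].get_tightbbox(renderer)
# bbox_inches expects inches: the inverse DPI transform divides by f.dpi.
bbox_in = bbox_px.transformed(f.dpi_scale_trans.inverted())

f.savefig("left_subplot.png", bbox_inches=bbox_in)
```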

Putting it all together, your code would look something like this:

import matplotlib.pyplot as plt

f, ax = plt.subplots(1, 2)
for i, img in enumerate([img1, img2]):
    grads = ...  # placeholder for your visualization code
    # visualize grads as heatmap
    ax[i].imshow(grads, cmap='jet')

    # Save the subplot: get its bounding box in display (pixel) units,
    # then convert it to inches, which is what bbox_inches expects.
    bbox = ax[i].get_tightbbox(f.canvas.get_renderer())
    f.savefig("subplot{}.png".format(i),
              bbox_inches=bbox.transformed(f.dpi_scale_trans.inverted()))

# Save the whole figure.
f.savefig("whole_figure.png")
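If only the heatmap pixels are wanted (no axes, ticks or figure margins), one alternative not covered by the answer is matplotlib.pyplot.imsave, which applies a colormap to a 2D array and writes one pixel per element. A sketch, with random arrays as hypothetical stand-ins for the per-image grads:

```python
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the gradient maps computed in the question.
rng = np.random.default_rng(0)
grads_list = [rng.random((32, 48)), rng.random((32, 48))]

for i, grads in enumerate(grads_list):
    # Colormap applied per array element; no Axes, ticks or padding are drawn.
    plt.imsave("heatmap{}.png".format(i), grads, cmap="jet")
```

Like imshow, imsave scales each array to its own data range by default; passing the same vmin and vmax to every call keeps the colors comparable between images.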

Upvotes: 5