cadams

Reputation: 1415

Eliminate white space between subplots in a matplotlib figure

I am trying to create a nice plot that joins a 4x4 grid of subplots (placed with gridspec; each subplot shows an 8x8 array of pixels). I constantly struggle to get the spacing between the plots to match what I tell it to do. I imagine the problem arises from plotting a colorbar on the right side of the figure and adjusting the location of the plots to accommodate it. However, the issue crops up even without the colorbar, which has confused me further. It may also have to do with the margin spacing. The images below are produced by the associated code. As you can see, I am trying to get the space between the plots to go to zero, but it doesn't seem to work. Can anyone advise?

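import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

# placeholder data (assumption): 16 arrays of 8x8 values
physHeatArr = np.random.rand(16, 8, 8) * 1500
indices = range(16)

fig = plt.figure('W Heat Map', (18., 15.))
gs = gridspec.GridSpec(4, 4)
gs.update(wspace=0., hspace=0.)
for index in indices:
    loc = divmod(index, 4)  # placeholder: the (i, j) position is determined by the code
    ax = plt.subplot(gs[loc])
    c = ax.pcolor(physHeatArr[index, :, :], vmin=0, vmax=1500)
    # take off axes
    ax.axis('off')
    ax.set_aspect('equal')
fig.subplots_adjust(right=0.8, top=0.9, bottom=0.1)
cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.7])
cbar = fig.colorbar(c, cax=cbar_ax)
cbar_ax.tick_params(labelsize=16)
fig.savefig("heatMap.jpg")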

Rectangle figure

Similarly, when making a square figure without the colorbar:

fig = plt.figure('W Heat Map', (15., 15.))
gs = gridspec.GridSpec(4,4)
gs.update(wspace=0., hspace=0.)
for index in indices:
    loc = divmod(index, 4)  # placeholder: the (i, j) position is determined by the code
    ax = plt.subplot(gs[loc])
    c = ax.pcolor(physHeatArr[index,:,:], vmin=0, vmax=400, cmap=plt.get_cmap("Reds_r"))
    # take off axes 
    ax.axis('off')
    ax.set_aspect('equal')
fig.savefig("heatMap.jpg")

Square figure

Upvotes: 4

Views: 7638

Answers (1)

ImportanceOfBeingErnest

Reputation: 339660

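When the axes aspect ratio is set so that it does not adjust automatically (e.g. using set_aspect("equal") or a numeric aspect, or in general using imshow, which uses an equal aspect by default), there may be some white space between the subplots, even if wspace and hspace are set to 0. This happens because each axes is shrunk inside the grid cell it was given until it satisfies its aspect ratio. In order to eliminate the white space between subplots, you may have a look at the following questions: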

  1. How to remove gaps between *images* in matplotlib?
  2. How to combine gridspec with plt.subplots() to eliminate space between rows of subplots
  3. How to remove the space between subplots in matplotlib.pyplot?
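The shrinking described above can be observed directly. The following is a minimal sketch (random data and an arbitrary wide 8x4 inch figure are assumptions chosen to make the effect large): after drawing, `ax.get_position(original=True)` still reports the grid cell, while `ax.get_position()` reports the smaller box that was actually drawn to keep the aspect equal.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the example runs without a display
import matplotlib.pyplot as plt
import numpy as np

# A deliberately wide figure: each grid cell is wider than it is tall.
fig, axes = plt.subplots(2, 2, figsize=(8, 4))
fig.subplots_adjust(wspace=0, hspace=0)
for ax in axes.flat:
    ax.pcolor(np.random.rand(8, 8))
    ax.set_aspect("equal")
    ax.axis("off")

fig.canvas.draw()  # the aspect constraint is enforced at draw time

for ax in axes.flat:
    cell = ax.get_position(original=True)  # box assigned by the grid
    drawn = ax.get_position()              # box after the equal aspect is enforced
    print(f"cell width {cell.width:.3f} -> drawn width {drawn.width:.3f}")
```

The difference between the two widths is exactly the white space that appears between the columns, even though `wspace` is 0.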

You may first consider this answer to the first question, where the solution is to build a single array out of the individual arrays and then plot this single array using pcolor, pcolormesh or imshow. This makes it especially easy to add a colorbar later on.
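A minimal sketch of the single-array approach, assuming 16 tiles of 8x8 values placed row by row (tile `k` goes to grid row `k // 4`, column `k % 4`); the helper name `mosaic` and the random data are hypothetical. NumPy's `reshape` and `transpose` do the stitching without a Python loop, and a single `imshow` call plus one colorbar replaces the 16 subplots, so there are no gaps to remove.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the example runs without a display
import matplotlib.pyplot as plt
import numpy as np

n_rows, n_cols, tile = 4, 4, 8
tiles = np.random.rand(n_rows * n_cols, tile, tile) * 1500  # placeholder data

def mosaic(tiles, n_rows, n_cols):
    """Stitch tiles[k] into grid cell (k // n_cols, k % n_cols) of one 2D array."""
    _, h, w = tiles.shape
    return (tiles.reshape(n_rows, n_cols, h, w)
                 .transpose(0, 2, 1, 3)  # -> (grid row, pixel row, grid col, pixel col)
                 .reshape(n_rows * h, n_cols * w))

big = mosaic(tiles, n_rows, n_cols)

fig, ax = plt.subplots(figsize=(8, 7))
im = ax.imshow(big, vmin=0, vmax=1500)
ax.axis("off")
fig.colorbar(im, ax=ax)  # one colorbar for the whole mosaic
```

Note that imshow draws row 0 of an array at the top, whereas pcolor (as used in the question) draws it at the bottom; passing `origin="lower"` to imshow gives the pcolor orientation, which also moves grid row 0 to the bottom.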

Otherwise, consider setting the figure size and subplot parameters such that no white space remains. Formulas for that calculation are found in this answer to the second question.
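The sizing condition can be derived as follows (a sketch using the symbols of the code in this answer). Let $W$ and $H$ be `figwidth` and `figheight` in inches, $l, r, b, t$ the `left`, `right`, `bottom`, `top` fractions, $s_w$ and $s_h$ the `wspace` and `hspace` values, $a$ the height/width ratio (`aspect`) of each axes, and $w$ the width of one axes in inches. Matplotlib leaves a gap of $s_w w$ between columns and $s_h a w$ between rows, so an $n \times m$ grid exactly fills the subplot area when

```latex
\begin{aligned}
W\,(r-l) &= w\,\bigl(m + (m-1)\,s_w\bigr),\\
H\,(t-b) &= a\,w\,\bigl(n + (n-1)\,s_h\bigr).
\end{aligned}
\qquad\Longrightarrow\qquad
W = \frac{m + (m-1)\,s_w}{a\,\bigl(n + (n-1)\,s_h\bigr)}\cdot\frac{t-b}{r-l}\cdot H
```

Dividing the two conditions eliminates $w$; the factor $(t-b)/(r-l)$ is `fisasp` in the code. Choosing $s_h = s_w / a$ makes the vertical gap $s_h a w$ equal to the horizontal gap $s_w w$.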

An adapted version with colorbar would look like this:

import matplotlib.pyplot as plt
import matplotlib.colors
import matplotlib.cm
import numpy as np

image = np.random.rand(16,8,8)
aspect = 1.  # height/width ratio of each subplot (8x8 images -> 1)
n = 4  # number of rows
m = 4  # number of columns
bottom = 0.1; left = 0.05
top = 1. - bottom; right = 1. - 0.18
# height/width ratio of the region occupied by the subplots, in figure fractions
fisasp = (1 - bottom - (1 - top)) / float(1 - left - (1 - right))
# width space, relative to the subplot width
wspace = 0  # set to zero for no spacing
hspace = wspace / float(aspect)
# fix the figure height
figheight = 4  # inch
figwidth = (m + (m - 1) * wspace) / float((n + (n - 1) * hspace) * aspect) * figheight * fisasp

fig, axes = plt.subplots(nrows=n, ncols=m, figsize=(figwidth, figheight))
plt.subplots_adjust(top=top, bottom=bottom, left=left, right=right, 
                    wspace=wspace, hspace=hspace)
#use a normalization to make sure the colormapping is the same for all subplots
norm=matplotlib.colors.Normalize(vmin=0, vmax=1 )
for i, ax in enumerate(axes.flatten()):
    ax.imshow(image[i, :,:], cmap = "RdBu", norm=norm)
    ax.axis('off')
# use a scalarmappable derived from the norm instance to create colorbar
sm = matplotlib.cm.ScalarMappable(cmap="RdBu", norm=norm)
sm.set_array([])
cax = fig.add_axes([right+0.035, bottom, 0.035, top-bottom])
fig.colorbar(sm, cax=cax)

plt.show()
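The sizing arithmetic above can be packaged as a function and checked numerically. This is a sketch, not part of the original answer: the name `no_gap_figsize` is hypothetical, and the check relies on the fact that, if the figure size is right, enforcing imshow's equal aspect does not shrink any axes, so the drawn boxes coincide with the grid cells and neighbouring axes touch.

```python
import matplotlib
matplotlib.use("Agg")  # off-screen backend so the example runs without a display
import matplotlib.pyplot as plt
import numpy as np

def no_gap_figsize(n, m, aspect, left, right, bottom, top, wspace=0.0, figheight=4.0):
    """Return (width, height) in inches so that an n x m grid of axes with
    height/width ratio `aspect` exactly fills [left, right] x [bottom, top]."""
    hspace = wspace / aspect                    # equal gaps in both directions
    fisasp = (top - bottom) / (right - left)    # aspect of the subplot region
    figwidth = (m + (m - 1) * wspace) / ((n + (n - 1) * hspace) * aspect) * figheight * fisasp
    return figwidth, figheight

n, m = 4, 4
left, right, bottom, top = 0.05, 0.82, 0.1, 0.9
fig, axes = plt.subplots(n, m, figsize=no_gap_figsize(n, m, 1.0, left, right, bottom, top))
fig.subplots_adjust(left=left, right=right, bottom=bottom, top=top, wspace=0, hspace=0)
for ax in axes.flat:
    ax.imshow(np.random.rand(8, 8))  # imshow uses an equal aspect by default
    ax.axis("off")
fig.canvas.draw()  # the aspect is enforced here

# True for any axes whose drawn box differs from its grid cell
shrunk = [not np.allclose(ax.get_position().bounds, ax.get_position(original=True).bounds)
          for ax in axes.flat]
print("any axes shrunk:", any(shrunk))
```

Changing `figheight` or the margins keeps the result gap-free as long as the width is recomputed with the same function.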


Upvotes: 7
