I have some numpy image arrays, all of the same shape (say (64, 64, 3)). I want to plot them in a grid using pyplot.subplot(), but when I do, I get unwanted spacing between images, even when I use pyplot.subplots_adjust(hspace=0, wspace=0). Below is an example piece of code.
```python
from matplotlib import pyplot
import numpy


def create_dummy_images():
    """
    Creates images, each of shape (64, 64, 3) and of dtype 8-bit unsigned integer.
    :return: 4 images in a list.
    """
    saturated_channel = numpy.ones((64, 64), dtype=numpy.uint8) * 255
    zero_channel = numpy.zeros((64, 64), dtype=numpy.uint8)
    red = numpy.array([saturated_channel, zero_channel, zero_channel]).transpose(1, 2, 0)
    green = numpy.array([zero_channel, saturated_channel, zero_channel]).transpose(1, 2, 0)
    blue = numpy.array([zero_channel, zero_channel, saturated_channel]).transpose(1, 2, 0)
    # dtype=uint8 so this image matches the docstring like the other three
    random = numpy.random.randint(0, 256, (64, 64, 3), dtype=numpy.uint8)
    return [red, green, blue, random]


if __name__ == "__main__":
    images = create_dummy_images()
    for i, image in enumerate(images):
        pyplot.subplot(2, 2, i + 1)
        pyplot.axis("off")
        pyplot.imshow(image)
    pyplot.subplots_adjust(hspace=0, wspace=0)
    pyplot.show()
```
In the resulting 2 × 2 grid there is unwanted vertical space between the images. One way of circumventing this problem is to carefully hand-pick the right size for the figure; for example, in Jupyter Notebook I set matplotlib.rcParams['figure.figsize'] = (_, _). However, the number of images I want to plot varies from one time to the next, and hand-picking the right figure size each time is extremely inconvenient (especially because I can't work out exactly what the size means in Matplotlib). So, is there a way for Matplotlib to work out automatically what size the figure should be, given my requirement that all my 64 × 64 images sit flush against each other? (Or, for that matter, a specified distance apart?)
NOTE: the correct answer is in the update below the original answer.
Create your subplots first, then plot in them. For simplicity's sake, I put them all in a single row here:
```python
from matplotlib import pyplot

images = create_dummy_images()  # defined in the question
fig, axs = pyplot.subplots(nrows=1, ncols=4, gridspec_kw={'wspace': 0, 'hspace': 0},
                           squeeze=True)
for i, image in enumerate(images):
    axs[i].axis("off")
    axs[i].imshow(image)
pyplot.show()
```
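To check whether a gap comes from the grid spacing or from imshow() itself, you can compare the slot the layout assigns to an axes with the box it is actually drawn in. The sketch below assumes a single axes in a deliberately wide 8 × 4 inch figure; the helper name axes_boxes is hypothetical. It relies on Axes.get_position(original=True) returning the layout slot and get_position() returning the active box after a draw.

```python
from matplotlib import pyplot
import numpy


def axes_boxes(aspect, figsize=(8, 4)):
    """Return (slot, drawn) boxes of an axes showing a square image.

    slot is the position the layout assigned to the axes; drawn is the
    position it actually occupies once imshow's aspect has been applied.
    Both are Bbox objects in figure-fraction coordinates.
    """
    image = numpy.zeros((64, 64, 3), dtype=numpy.uint8)  # any square image
    fig, ax = pyplot.subplots(figsize=figsize)
    ax.imshow(image, aspect=aspect)
    fig.canvas.draw()  # the aspect ratio is only enforced at draw time
    slot = ax.get_position(original=True)
    drawn = ax.get_position()
    pyplot.close(fig)
    return slot, drawn


for aspect in ("equal", "auto"):
    slot, drawn = axes_boxes(aspect)
    print(aspect, "slot width:", round(slot.width, 3), "drawn width:", round(drawn.width, 3))
```

With aspect='equal' the drawn box should come out narrower than its slot, because the axes is shrunk to keep the pixels square. With aspect='auto' the two boxes should coincide.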
UPDATE:
Never mind: the problem was not with your subplot definition but with imshow(). With its default aspect='equal', imshow() shrinks each axes so that the pixels stay square, which undoes the spacing you set up correctly.
The solution is to pass aspect='auto' to imshow() so that each picture fills its axes without changing them. If you want the images to stay square, you need to create a figure with the appropriate width/height ratio:
```python
from matplotlib import pyplot

pyplot.figure(figsize=(5, 5))
images = create_dummy_images()  # defined in the question
for i, image in enumerate(images):
    pyplot.subplot(2, 2, i + 1)
    pyplot.axis("off")
    pyplot.imshow(image, aspect='auto')  # fill the axes instead of resizing it
pyplot.subplots_adjust(hspace=0, wspace=0)
pyplot.show()
```
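Because the number of images in the question varies, the same idea can be generalized instead of hand-picking figsize each time. You can compute the figure size from the number of rows and columns and the image shape, so that every grid cell has exactly the image's aspect ratio and imshow()'s default aspect='equal' has nothing to shrink. This is a minimal sketch, not part of the original answer; plot_image_grid and its parameters are hypothetical names. It relies on GridSpec interpreting wspace/hspace as fractions of the average axes width/height, which also lets you request a specified gap. Note that figsize is measured in inches; the size in pixels is figsize multiplied by the figure's dpi.

```python
import math

import numpy
from matplotlib import pyplot


def plot_image_grid(images, ncols=4, width_inches=8.0, gap=0.0):
    """Plot equally sized images in a grid whose cells match the image aspect ratio.

    gap is the space between neighbouring images as a fraction of one image's
    width (horizontally) or height (vertically), which is how GridSpec
    interprets wspace and hspace.
    """
    nrows = math.ceil(len(images) / ncols)
    img_h, img_w = numpy.shape(images[0])[:2]
    # Solve width_inches = ncols * cell_w + (ncols - 1) * gap * cell_w for cell_w.
    cell_w = width_inches / (ncols + (ncols - 1) * gap)
    # Give each cell the images' aspect ratio so aspect='equal' does not shrink it.
    cell_h = cell_w * img_h / img_w
    height_inches = cell_h * (nrows + (nrows - 1) * gap)

    fig, axs = pyplot.subplots(nrows, ncols, figsize=(width_inches, height_inches),
                               squeeze=False)
    # No outer margins; only the requested gap between cells.
    fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=gap, hspace=gap)
    for ax in axs.flat:
        ax.axis("off")  # also hides unused cells in an incomplete last row
    for ax, image in zip(axs.flat, images):
        ax.imshow(image)
    return fig
```

For example, plot_image_grid(create_dummy_images(), ncols=2, width_inches=4) followed by pyplot.show() should draw the four images from the question flush against each other in a 2 × 2 grid. Passing gap=0.1 should leave one tenth of an image width between neighbours.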