I'm trying to plot 10 samples from the MNIST dataset. One of each digit. Here's the code:
```python
import sklearn
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

mnist = datasets.fetch_mldata('MNIST original')
y = mnist.target
X = mnist.data

for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    plt.imshow(plottable_image, cmap='gray_r')
    plt.subplot(2, 5, i + 1)
    plt.plot()
```
For some reason, the zero digit is being skipped in the plot.
Why?
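OK, I got it. The problem is that you create the subplot after calling imshow. plt.imshow draws into the current axes, and plt.subplot(2, 5, i + 1) only makes cell i + 1 current after the image has already been drawn, so every image lands in the cell created on the previous iteration. On the first iteration no subplot exists yet, so the 0 goes into a default full-figure axes; in the matplotlib versions of the time, the first plt.subplot call removed that overlapping axes, taking the 0 with it, and the tenth cell was left empty. To make your code work, just swap the order of your two commands as follows. Also, the plt.plot() call at the end isn't needed: with no arguments it draws nothing.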
```python
plt.subplot(2, 5, i + 1)  # <-- you had put this command after imshow
plt.imshow(plottable_image, cmap='gray_r')
```
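To see the mechanism without downloading MNIST, here is a minimal sketch that compares the two orders. It assumes ten hypothetical 28x28 arrays (image k is filled with the value k) as stand-ins for the digits, and a non-interactive backend so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the MNIST digits: image k is a 28x28 array filled with k
images = [np.full((28, 28), k, dtype=float) for k in range(10)]


def draw(fixed_order):
    """Reproduce the question's loop, with or without the subplot/imshow swap."""
    fig = plt.figure()
    for i, img in enumerate(images):
        if fixed_order:
            plt.subplot(2, 5, i + 1)        # make grid cell i + 1 the current axes first
            plt.imshow(img, cmap="gray_r")  # ...then draw into it
        else:
            plt.imshow(img, cmap="gray_r")  # draws into whatever axes is current
            plt.subplot(2, 5, i + 1)        # only now creates grid cell i + 1
    return fig


fixed = draw(fixed_order=True)
buggy = draw(fixed_order=False)

# Fixed order: each of the ten cells shows the image with the matching value
print([int(ax.images[0].get_array()[0, 0]) for ax in fixed.axes])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

# Original order: the tenth cell is created last and never drawn into
print(len(buggy.axes[-1].images))
# 0
```

The check only inspects the last cell of the buggy figure because what happens to the default axes holding the 0 depends on the matplotlib version.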
Here is an alternative for your reference that uses the object-oriented API, where each image is drawn on an explicit Axes object:
```python
fig = plt.figure()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax = fig.add_subplot(2, 5, i + 1)
    ax.imshow(plottable_image, cmap='gray_r')
```
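The object-oriented version is immune to the ordering problem because imshow is called on a specific Axes rather than on pyplot's current axes. A small sketch, again assuming hypothetical constant-valued arrays in place of MNIST, creates all ten cells first and then fills them in reverse order to show that the drawing order no longer matters:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for the digits: image k is a 28x28 array filled with k
images = [np.full((28, 28), k, dtype=float) for k in range(10)]

fig = plt.figure()
# Create all ten grid cells up front; add_subplot returns the Axes it creates
axes = [fig.add_subplot(2, 5, i + 1) for i in range(10)]

# Draw in reverse order: each image still goes to the Axes it is called on
for i in reversed(range(10)):
    axes[i].imshow(images[i], cmap="gray_r")

print([int(ax.images[0].get_array()[0, 0]) for ax in axes])
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
```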
You can also further shorten Scott's code (posted below) by using the following:
```python
fig, axes = plt.subplots(2, 5)
for i, ax in enumerate(axes.flatten()):  # separate names for the array and each Axes
    im_idx = np.argwhere(y == i)[0]
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax.imshow(plottable_image, cmap='gray_r')
```
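For reuse, the same idea can be wrapped in a helper. This is a sketch: plot_one_per_digit is a hypothetical name, and it assumes X is an (n_samples, 784) NumPy array and y holds integer labels 0-9. np.unique(y, return_index=True) returns the index of the first occurrence of each label in one call, which is equivalent to np.argwhere(y == i)[0] for each digit. The demo uses random hypothetical data so it runs without a download.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; drop this line to get a window
import matplotlib.pyplot as plt
import numpy as np


def plot_one_per_digit(X, y):
    """Hypothetical helper: show the first sample of each label in a 2x5 grid.

    Assumes X has shape (n_samples, 784) and y holds integer labels 0-9.
    """
    labels, first_idx = np.unique(y, return_index=True)  # first occurrence of each label
    fig, axes = plt.subplots(2, 5, figsize=(10, 4))
    for ax, label, idx in zip(axes.flat, labels, first_idx):
        ax.imshow(X[idx].reshape(28, 28), cmap="gray_r")
        ax.set_title(str(label))
        ax.set_axis_off()  # ticks carry no information for these images
    return fig


# Synthetic check with hypothetical data: 30 random "images", labels cycling 0..9
rng = np.random.default_rng(0)
X_demo = rng.random((30, 784))
y_demo = np.arange(30) % 10
fig = plot_one_per_digit(X_demo, y_demo)
print([ax.get_title() for ax in fig.axes])
# ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
```

With the real data on a current scikit-learn, where fetch_mldata no longer exists, the labels from fetch_openml('mnist_784', version=1, as_frame=False) are strings, so pass mnist.target.astype(int) as y.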
Try this:
```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets

# fetch_mldata('MNIST original') was removed in scikit-learn 0.22;
# fetch_openml replaces it (labels come back as strings, hence astype)
mnist = datasets.fetch_openml('mnist_784', version=1, as_frame=False)
y = mnist.target.astype(int)
X = mnist.data

fig, ax = plt.subplots(2, 5)
ax = ax.flatten()
for i in range(10):
    im_idx = np.argwhere(y == i)[0]
    print(im_idx)
    plottable_image = np.reshape(X[im_idx], (28, 28))
    ax[i].imshow(plottable_image, cmap='gray_r')
plt.show()
```
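The ax.flatten() step matters because plt.subplots(2, 5) returns the Axes in a 2-D NumPy array of shape (2, 5), so a single integer index would select a whole row rather than one cell. A short sketch of the equivalent ways to reach cell i:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend
import matplotlib.pyplot as plt

fig, ax = plt.subplots(2, 5)
print(ax.shape)  # (2, 5)

flat = ax.flatten()  # row-major 1-D array holding the same Axes objects
for i in range(10):
    row, col = divmod(i, 5)  # position of cell i in the 2-D array
    assert flat[i] is ax[row, col]  # same Axes object either way

print(ax[1].shape)  # (5,): one integer index picks a whole row
```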