mats

Reputation: 1

How to display single MNIST digits, each in one row?

I would like to display the single digits from 0-9, each digit in its own row with 5 examples per digit, as in the attached example photo. In my case that makes 10 rows and 5 columns.

I managed to display the first 50 images (sorry for posting my code as an image; I was not able to format it on Stack Overflow).

How can I display each digit, with its label, in its own row? I have been trying for hours, but I have no idea how to do it, apart from maybe using numpy.take(), and I don't know how to apply that here. I have googled a lot with no usable result.
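For reference, numpy.take() can do the gathering step: build a (10, 5) array of image indices (one row per digit) and take along axis 0, and the result keeps that 2-D layout. The sketch below is illustrative only; it uses hypothetical random stand-in arrays `images` and `labels` instead of the real MNIST data.

```python
import numpy as np

# Hypothetical stand-in data: 100 random 28x28 "images" with labels cycling 0-9,
# so every digit appears 10 times. Replace with the real MNIST arrays.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(100, 28, 28)).astype(np.uint8)
labels = np.arange(100) % 10

n_per_digit = 5

# For each digit, the positions of its first n_per_digit occurrences -> shape (10, 5)
idx = np.array([np.flatnonzero(labels == d)[:n_per_digit] for d in range(10)])

# np.take with axis=0 gathers whole images; a 2-D index array keeps its shape,
# so the result is (10, 5, 28, 28): row = digit, column = example
selected = np.take(images, idx, axis=0)
print(selected.shape)  # (10, 5, 28, 28)
```

Plain fancy indexing, `images[idx]`, gives the same array. Note that the list comprehension assumes every digit has at least 5 examples; otherwise the rows have different lengths and `np.array` cannot build a regular (10, 5) index array.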

Thanks in advance!

Upvotes: 0

Views: 2101

Answers (1)

TBM - VOICE

Reputation: 163

First you need the dataset:

from tensorflow import keras
dataset = keras.datasets.mnist.load_data()

Then you split it:

X_train = dataset[0][0]
y_train = dataset[0][1]
X_test = dataset[1][0]
y_test = dataset[1][1]
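`load_data()` returns a nested tuple `((x_train, y_train), (x_test, y_test))` (for MNIST, images of shape (60000, 28, 28) and (10000, 28, 28)), so the four indexing lines can also be written as one unpacking assignment. To keep the sketch runnable offline, it uses a hypothetical stand-in `fake_load_data()` with the same nesting but tiny random arrays; with Keras installed you would call `keras.datasets.mnist.load_data()` instead.

```python
import numpy as np

def fake_load_data():
    """Hypothetical stand-in for keras.datasets.mnist.load_data(): same nesting
    ((x_train, y_train), (x_test, y_test)), but tiny random arrays, not MNIST."""
    rng = np.random.default_rng(0)
    x_train = rng.integers(0, 256, size=(60, 28, 28)).astype(np.uint8)
    y_train = np.arange(60) % 10
    x_test = rng.integers(0, 256, size=(20, 28, 28)).astype(np.uint8)
    y_test = np.arange(20) % 10
    return (x_train, y_train), (x_test, y_test)

dataset = fake_load_data()

# Indexing style used in the answer
X_test_a, y_test_a = dataset[1][0], dataset[1][1]

# Equivalent one-line tuple unpacking
(X_train, y_train), (X_test, y_test) = dataset

print(X_test is X_test_a, y_test is y_test_a)  # True True
```

Both forms bind the very same array objects; unpacking just makes the structure of the return value explicit.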

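Then you can create a dictionary that maps each digit to the indices of its first five examples.

Here I use the test set. You can use the training set instead; just replace y_test with y_train (and X_test with X_train when plotting):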

import numpy as np

digits = {}

for i in range(10):
    # indices of the first 5 test images whose label is i
    digits[i] = np.where(y_test==i)[0][:5]

digits

This dict will look like this:

{0: array([ 3, 10, 13, 25, 28], dtype=int64),
 1: array([ 2,  5, 14, 29, 31], dtype=int64),
 2: array([ 1, 35, 38, 43, 47], dtype=int64),
 3: array([18, 30, 32, 44, 51], dtype=int64),
 4: array([ 4,  6, 19, 24, 27], dtype=int64),
 5: array([ 8, 15, 23, 45, 52], dtype=int64),
 6: array([11, 21, 22, 50, 54], dtype=int64),
 7: array([ 0, 17, 26, 34, 36], dtype=int64),
 8: array([ 61,  84, 110, 128, 134], dtype=int64),
 9: array([ 7,  9, 12, 16, 20], dtype=int64)}
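The same dictionary can be built with a comprehension, and it is worth checking that every stored index really points at the right digit and that each digit has enough examples before plotting a fixed grid. In the sketch below, `first_indices_per_digit` is a hypothetical helper, and `y_small` holds the first ten test labels as read off the dictionary printed above (index 0 is a 7, index 1 a 2, index 2 a 1, and so on).

```python
import numpy as np

def first_indices_per_digit(labels, n=5, n_classes=10):
    """Map each class to the indices of its first n occurrences in labels
    (same idea as the loop above, written as a dict comprehension)."""
    return {d: np.where(labels == d)[0][:n] for d in range(n_classes)}

# First ten MNIST test labels, reconstructed from the dictionary printed above.
y_small = np.array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9])
digits_small = first_indices_per_digit(y_small, n=5)

# Sanity check: every stored index points at an image of that digit.
assert all(np.all(y_small[idx] == d) for d, idx in digits_small.items())

# With so few labels, some digits get fewer than n indices (or none), so a
# fixed 10 x 5 plotting loop would raise IndexError on those rows.
print(digits_small[1])  # [2 5]
print(digits_small[3])  # []
```

On the full test set every digit occurs far more than five times, so this only matters when working with a small slice of the data.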

Finally, create a figure with a 10 x 5 grid of subplots and draw one image in each:

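import matplotlib.pyplot as plt
# 10 rows (one per digit) x 5 columns (one per example)
fig, ax = plt.subplots(10, 5, sharex='col', sharey='row')
for i in range(10):
    for j in range(5):
        # digits[i][j] is the index of the j-th test image showing digit i
        ax[i, j].imshow(X_test[digits[i][j]], cmap='gray')
    ax[i, 0].set_ylabel(str(i))  # label each row with its digit
plt.show()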

Result
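An alternative to 50 separate subplots is to tile the selected images into one large array and show it with a single `imshow` call. This is a different technique from the subplot grid above, sketched here with a hypothetical helper `tile_digits` and stand-in images in which image k is filled with the constant value k, so the layout is easy to check.

```python
import numpy as np

def tile_digits(images, digits, n_cols=5):
    """Stack digits[d][:n_cols] images for d = 0..9 into one (10*h, n_cols*w)
    array, one digit per row, so a single imshow call can show the whole grid."""
    idx = np.array([digits[d][:n_cols] for d in range(10)])  # (10, n_cols)
    block = images[idx]                                      # (10, n_cols, h, w)
    rows, cols, h, w = block.shape
    # (10, n_cols, h, w) -> (10, h, n_cols, w) -> (10*h, n_cols*w)
    return block.transpose(0, 2, 1, 3).reshape(rows * h, cols * w)

# Hypothetical stand-in data: image k is a 28x28 array filled with the value k.
images = np.arange(100, dtype=np.float32).reshape(100, 1, 1) * np.ones((28, 28), dtype=np.float32)
labels = np.arange(100) % 10
digits = {d: np.where(labels == d)[0][:5] for d in range(10)}

grid = tile_digits(images, digits)
print(grid.shape)  # (280, 140)
```

With the real data you would pass `X_test` and the `digits` dictionary from above, then display the result with `plt.imshow(grid, cmap='gray')`, `plt.axis('off')` and `plt.show()`. The trade-off is that you lose per-subplot labels, but drawing is faster and there are no gaps between images.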

Upvotes: 7
