arush1836

Trouble plotting MNIST digits

I am trying to load and visualize MNIST digits, but the digits come out with shifted pixels.

import matplotlib.pyplot as plt
import numpy as np

mnist_data = open('data/mnist/train-images-idx3-ubyte', 'rb')

image_size = 28
num_images = 4

buf = mnist_data.read(num_images * image_size * image_size)
data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(num_images, image_size, image_size)

_, axarr1 = plt.subplots(2,2)
axarr1[0, 0].imshow(data[0])
axarr1[0, 1].imshow(data[1])
axarr1[1, 0].imshow(data[2])
axarr1[1, 1].imshow(data[3])


Can anyone tell me why this is happening? The code seems fine to me. Thank you.

Answers (1)

Diziet Asahi

You don't say where you obtained the MNIST data, but if it is formatted like the original data set (the IDX format), you seem to have forgotten to skip the 16-byte header before reading the pixel data:

import matplotlib.pyplot as plt
import numpy as np

image_size = 28
num_images = 4

with open('train-images-idx3-ubyte', 'rb') as mnist_data:
    mnist_data.seek(16)  # skip over the first 16 bytes that correspond to the header
    buf = mnist_data.read(num_images * image_size * image_size)

data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(num_images, image_size, image_size)

_, axarr1 = plt.subplots(2, 2)
axarr1[0, 0].imshow(data[0])
axarr1[0, 1].imshow(data[1])
axarr1[1, 0].imshow(data[2])
axarr1[1, 1].imshow(data[3])
plt.show()
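To see why the unskipped header produces the shift, here is a small self-contained check on a synthetic in-memory IDX3 buffer. The pixel values are made up (pixel index modulo 256), not real MNIST data, and the names `HEADER_SIZE`, `wrong` and `right` are only for illustration. Reading from byte 0 treats the 16 header bytes as pixels, so every real pixel lands 16 positions later than it should: each image is pushed 16 pixels to the right (wrapping into the next row), and its last 16 pixels spill into the next image.

```python
import struct

import numpy as np

IMAGE_SIZE = 28
NUM_IMAGES = 4
HEADER_SIZE = 16  # four big-endian uint32 fields in an IDX3 image file

# Build a synthetic IDX3 file in memory: header followed by pixel bytes.
pixels = (np.arange(NUM_IMAGES * IMAGE_SIZE * IMAGE_SIZE) % 256).astype(np.uint8)
header = struct.pack(">IIII", 2051, NUM_IMAGES, IMAGE_SIZE, IMAGE_SIZE)
raw = header + pixels.tobytes()

n = NUM_IMAGES * IMAGE_SIZE * IMAGE_SIZE
# Reading from byte 0 (as in the question) treats the header as pixel data.
wrong = np.frombuffer(raw[:n], dtype=np.uint8)
# Skipping the header (as in this answer) lines the pixels up correctly.
right = np.frombuffer(raw[HEADER_SIZE:HEADER_SIZE + n], dtype=np.uint8)

# Every pixel in `wrong` sits HEADER_SIZE positions later than in `right`.
shift_ok = np.array_equal(wrong[HEADER_SIZE:], right[:-HEADER_SIZE])
print(np.array_equal(right, pixels), shift_ok)  # prints: True True
```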

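If you would rather not hard-code the offset, one option is to parse the header yourself with Python's `struct` module and take the image count and dimensions from it. This is a sketch assuming the standard IDX3 image layout: a magic number of 2051 followed by the image count, the number of rows and the number of columns, each a big-endian 32-bit unsigned integer. The function name `load_idx3_images` is hypothetical and not part of any library.

```python
import struct

import numpy as np


def load_idx3_images(path, num_images=None):
    """Hypothetical helper: read images from an IDX3 file such as
    train-images-idx3-ubyte. Returns float32 of shape (n, rows, cols)."""
    with open(path, "rb") as f:
        # Assumed header layout: magic, image count, rows, cols (big-endian uint32).
        magic, count, rows, cols = struct.unpack(">IIII", f.read(16))
        if magic != 2051:
            raise ValueError(f"not an IDX3 image file (magic number {magic})")
        # Read all images if no count is given, and never more than the file holds.
        if num_images is None or num_images > count:
            num_images = count
        # The file position is now just past the header, so no seek is needed.
        buf = f.read(num_images * rows * cols)
    data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
    return data.reshape(num_images, rows, cols)
```

For the four digits in the question, `data = load_idx3_images('data/mnist/train-images-idx3-ubyte', num_images=4)` should give the same array as the fixed code above, which can then be plotted with `imshow` as before. Checking the magic number turns a wrong file or wrong offset into an immediate error instead of a silently shifted image.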
