I am trying to load and visualize MNIST digits, but the digits come out with shifted pixels:
import matplotlib.pyplot as plt
import numpy as np
mnist_data = open('data/mnist/train-images-idx3-ubyte', 'rb')
image_size = 28
num_images = 4
buf = mnist_data.read(num_images * image_size * image_size)
data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(num_images, image_size, image_size)
_, axarr1 = plt.subplots(2,2)
axarr1[0, 0].imshow(data[0])
axarr1[0, 1].imshow(data[1])
axarr1[1, 0].imshow(data[2])
axarr1[1, 1].imshow(data[3])
Can anyone tell me why this is happening? The code seems fine. Thank you.
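You don't say where you obtained the MNIST data, but if it is formatted like the original data set, you seem to have forgotten to skip the 16-byte header before reading the pixel data. Without the skip, those header bytes are read as the first 16 "pixels", and every real pixel lands 16 positions later than it should, which wraps the digits across row boundaries: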
import matplotlib.pyplot as plt
import numpy as np
image_size = 28
num_images = 4
with open('train-images-idx3-ubyte', 'rb') as mnist_data:
    mnist_data.seek(16)  # skip the 16-byte header (magic number, image count, rows, columns)
    buf = mnist_data.read(num_images * image_size * image_size)
data = np.frombuffer(buf, dtype=np.uint8).astype(np.float32)
data = data.reshape(num_images, image_size, image_size)
_, axarr1 = plt.subplots(2, 2)
axarr1[0, 0].imshow(data[0])
axarr1[0, 1].imshow(data[1])
axarr1[1, 0].imshow(data[2])
axarr1[1, 1].imshow(data[3])
plt.show()
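Rather than hard-coding the offset and the 28x28 size, you could read them from the header itself. The sketch below assumes the standard idx3-ubyte layout used by the original MNIST files: four big-endian unsigned 32-bit integers (magic number 2051, image count, rows, columns) followed by the raw pixel bytes. The helper name `read_idx3_images` is hypothetical, and the example runs on a tiny in-memory fake file so it works without the dataset.

```python
import struct

import numpy as np

IDX3_UBYTE_MAGIC = 2051  # magic number of an idx3-ubyte (unsigned byte, 3-D) file


def read_idx3_images(raw, num_images=None):
    """Hypothetical helper: parse idx3-ubyte bytes into an (n, rows, cols) uint8 array."""
    # Header: four big-endian unsigned 32-bit ints = 16 bytes in total.
    magic, count, rows, cols = struct.unpack(">IIII", raw[:16])
    if magic != IDX3_UBYTE_MAGIC:
        raise ValueError(f"unexpected magic number {magic}, expected {IDX3_UBYTE_MAGIC}")
    if num_images is None or num_images > count:
        num_images = count
    # Pixel data starts right after the 16-byte header.
    pixels = np.frombuffer(raw, dtype=np.uint8, count=num_images * rows * cols, offset=16)
    return pixels.reshape(num_images, rows, cols)


# A fake idx3 file: 2 images of 3x4 pixels with pixel values 0..23.
fake = struct.pack(">IIII", IDX3_UBYTE_MAGIC, 2, 3, 4) + bytes(range(24))
images = read_idx3_images(fake)
print(images.shape)   # (2, 3, 4)
print(images[1, 0])   # [12 13 14 15]

# Ignoring the header reproduces the bug: the first "pixels" are header bytes.
wrong = np.frombuffer(fake, dtype=np.uint8)[:24].reshape(2, 3, 4)
print(wrong[0, 0])    # [0 0 8 3]  (the magic number 2051 = 0x00000803)
```

With the real file you would pass the bytes of `train-images-idx3-ubyte` (for example `f.read()` on a file opened in `'rb'` mode) and optionally `num_images=4`, then convert with `.astype(np.float32)` as in the code above. Checking the magic number also catches the case where the file is still gzip-compressed.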