Raj Rajeshwari Prasad

Reputation: 334

Reshaping image and Plotting in Python

I am working with the Fashion-MNIST dataset. Each image is 28x28 pixels. To feed the images to a neural network (a multi-layer perceptron), I flattened each one to shape (784,).

Afterwards, I need to reshape the data back to its original size.

For this, I used the code below:

from keras.datasets import fashion_mnist
import numpy as np
import matplotlib.pyplot as plt


(train_imgs,train_lbls), (test_imgs, test_lbls) = fashion_mnist.load_data()
plt.imshow(test_imgs[0].reshape(28,28))

no_of_test_imgs  = test_imgs.shape[0]

test_imgs_trans  = test_imgs.reshape(test_imgs.shape[1]*test_imgs.shape[2], no_of_test_imgs).T

plt.imshow(test_imgs_trans[0].reshape(28,28))

Unfortunately, the second plot does not show the same image as the first, and I do not understand why.

Expected image: [image]

Received image: [image]

Kindly help me to resolve the problem.

Upvotes: 2

Views: 2268

Answers (2)

Marlon

Reputation: 1

I had a similar problem when plotting images. Each image was 224*224, and there were 1100 images in total. I wanted to plot a few (around ten) typical misclassified images. false_class contains the misclassified images.

import numpy as np
import matplotlib.pyplot as plt
plt.imshow(np.reshape(false_class[i][0], (-1, 300)))

The plotted images were folded over themselves and distorted. The problem was fixed by this statement:

plt.imshow(np.reshape(false_class[i][0], (-1, 672)))

This statement gave a proper image without any distortion. The reason is that the total array length was 150528, since the image is 224 * 224 * 3 (it is an RGB image). Hence the row length has to be a multiple of 224; here 672 = 224 * 3.
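As a rough illustration of the arithmetic above, here is a minimal sketch using a synthetic array in place of false_class[i][0]. The names h, w, c, original, flat, as_2d and as_rgb are hypothetical, and the sketch assumes the flattened image is stored row by row with the three colour channels last, which may not match the real data.

```python
import numpy as np

# Hypothetical stand-in for one flattened 224x224 RGB image.
h, w, c = 224, 224, 3
original = np.arange(h * w * c, dtype=np.float32).reshape(h, w, c)
flat = original.ravel()              # length 224 * 224 * 3 = 150528

# Row length 672 = 224 * 3: each row of the result holds exactly one
# image row, with the three channel values of each pixel side by side.
as_2d = flat.reshape(-1, w * c)      # shape (224, 672)

# Restoring the channel axis instead recovers the original colour image.
as_rgb = flat.reshape(h, w, c)       # shape (224, 224, 3)

print(as_2d.shape, as_rgb.shape)
```

Under that layout assumption, passing as_rgb to plt.imshow would show the image in colour, provided the values are scaled to 0-1 for floats or 0-255 for integers.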

Upvotes: 0

Marco Cerliani

Reputation: 22031

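Pay attention to how you flatten the images into test_imgs_trans. Reshape directly to (number of images, 784) so that each row holds the pixels of one image. Reshaping to (784, number of images) and then transposing mixes pixels from different images.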

import tensorflow as tf
import matplotlib.pyplot as plt

(train_imgs, train_lbls), (test_imgs, test_lbls) = tf.keras.datasets.fashion_mnist.load_data()

plt.imshow(test_imgs[0].reshape(28,28))

no_of_test_imgs  = test_imgs.shape[0]

test_imgs_trans  = test_imgs.reshape(no_of_test_imgs, test_imgs.shape[1]*test_imgs.shape[2])

plt.imshow(test_imgs_trans[0].reshape(28,28))
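To see why the order of the reshape dimensions matters, here is a small sketch that compares the two approaches on a tiny synthetic array instead of the real dataset. The array imgs and the names scrambled, flat, roundtrip_ok and scrambled_ok are hypothetical and exist only for this illustration.

```python
import numpy as np

# Hypothetical stand-in for test_imgs: 3 tiny "images" of 2x2 pixels,
# where every pixel of image k has the value k.
imgs = np.stack([np.full((2, 2), k) for k in range(3)])  # shape (3, 2, 2)
n_imgs, h, w = imgs.shape

# Approach from the question: reshape to (pixels, images), then transpose.
scrambled = imgs.reshape(h * w, n_imgs).T

# Approach from this answer: reshape directly to (images, pixels).
flat = imgs.reshape(n_imgs, h * w)

# Each row of `flat` reshapes back to the image it came from ...
roundtrip_ok = all(
    np.array_equal(flat[k].reshape(h, w), imgs[k]) for k in range(n_imgs)
)
# ... while rows of `scrambled` contain pixels from several images.
scrambled_ok = all(
    np.array_equal(scrambled[k].reshape(h, w), imgs[k]) for k in range(n_imgs)
)

print(roundtrip_ok, scrambled_ok)
```

Both results have the same shape, which is why the mistake is easy to miss: reshape only reinterprets the existing memory order, so the image axis has to stay first. Writing imgs.reshape(n_imgs, -1) is an equivalent way to flatten each image.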

Upvotes: 1
