Confused

Reputation: 1

Average image of MNIST

Working with the MNIST dataset, I am trying to find the average image for each distinct digit (0-9). The following code shows the first image of each digit in the dataset, but I am not sure how to get the mean image for each class (0-9).

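import numpy as np
import matplotlib.pyplot as plt
from scipy import io
from sklearn.model_selection import train_test_split

data = io.loadmat('mnist-original.mat')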

x, y = data['data'].T, data['label'].T

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.5)


a=np.unique(y, return_index=True)
b = a[1]

plt.figure(figsize=(15,4.5))
for i in b:
    img=x[i][:].reshape(28,28)
    plt.imshow(img)
    plt.show()  

Upvotes: 0

Views: 2183

Answers (2)

Jason

Reputation: 21

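Suppose the "average" image for a digit is the pixel-wise mean of all training images with that label. Note that y_train has shape (N, 1) here, so it needs to be flattened before it can be used as a row mask. For zero, for example:

avgImg = np.average(x_train[y_train.ravel() == 0], axis=0)

I think this is what you want for all ten digits: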

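import matplotlib.pyplot as plt
import numpy as np

labels = y_train.ravel()  # flatten (N, 1) labels to (N,) so they work as a row mask

plt.figure(figsize=(10,3))
for i in range(10):
    # pixel-wise mean over all training images of digit i
    avgImg = np.average(x_train[labels == i], axis=0)
    plt.subplot(2, 5, i+1)
    plt.imshow(avgImg.reshape((28,28)))  # MNIST images are 28x28 (784 pixels)
    plt.axis('off')
plt.show()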
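
The same masking idea can be checked without the MNIST file. This is only a minimal sketch: the helper name `class_means` and the tiny synthetic arrays are made up for illustration, and the labels are deliberately stored as a column vector to mimic the shape of `data['label'].T`.

```python
import numpy as np

def class_means(x, y, n_classes=10):
    """Return an (n_classes, n_features) array of per-class mean rows.

    Classes with no samples are left as NaN.
    """
    y = np.asarray(y).ravel()  # flatten (N, 1) labels to (N,)
    means = np.full((n_classes, x.shape[1]), np.nan)
    for c in range(n_classes):
        mask = y == c
        if mask.any():
            means[c] = x[mask].mean(axis=0)
    return means

# Synthetic stand-in for MNIST: 6 "images" of 4 pixels each (hypothetical data)
x_demo = np.array([[0, 0, 0, 0],
                   [2, 2, 2, 2],
                   [1, 3, 5, 7],
                   [3, 5, 7, 9],
                   [4, 4, 4, 4],
                   [8, 8, 8, 8]], dtype=float)
y_demo = np.array([[0], [0], [1], [1], [2], [2]])  # column vector, like the question's y

demo_means = class_means(x_demo, y_demo, n_classes=3)
```

With the real data, `class_means(x_train, y_train)[i].reshape(28, 28)` would be the average image for digit `i`.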

Upvotes: 2

Eelco Hoogendoorn

Reputation: 10759

The numpy_indexed package (disclaimer: I am its author) provides this type of functionality in a vectorized manner:

import numpy_indexed as npi
digits, means = npi.group_by(y).mean(x)
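
For readers who cannot install the package, a roughly equivalent grouped mean can be sketched in plain NumPy. This is not how numpy_indexed is implemented; `grouped_mean` is a hypothetical helper written for illustration, using `np.unique` with `return_inverse=True` and `np.add.at` to accumulate per-group sums.

```python
import numpy as np

def grouped_mean(keys, values):
    """Return (unique_keys, per-key mean rows of values) without a Python loop."""
    keys = np.asarray(keys).ravel()              # accept (N,) or (N, 1) labels
    values = np.asarray(values, dtype=float)     # (N, n_features)
    uniq, inverse = np.unique(keys, return_inverse=True)
    sums = np.zeros((uniq.size, values.shape[1]))
    np.add.at(sums, inverse, values)             # unbuffered add of each row into its group
    counts = np.bincount(inverse, minlength=uniq.size)
    return uniq, sums / counts[:, None]

# Tiny synthetic example (hypothetical data, not MNIST)
demo_keys = np.array([[1], [0], [1], [0]])
demo_values = np.array([[2, 4], [1, 1], [4, 8], [3, 5]])
demo_digits, demo_group_means = grouped_mean(demo_keys, demo_values)
```

Applied to the question's arrays, `grouped_mean(y, x)` would return the ten digit labels and a (10, 784) array of mean images.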

Upvotes: 0
