shadowrain

Reputation: 163

Visualize a 10x10 grid of each digit using MNIST samples

I'm trying to plot a 10x10 grid of samples from the MNIST dataset, with 10 examples of each digit. Here's the code:

Import libraries:

import sklearn
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator
from sklearn.pipeline import Pipeline
from sklearn.datasets import fetch_openml

Load digit data:

X, Y = fetch_openml(name='mnist_784', return_X_y=True, cache=False)

Plotting the grid:

def P1(num_examples=10):
    plt.rc('image', cmap='Greys')
    plt.figure(figsize=(num_examples,len(np.unique(Y))), dpi=X.shape[1])
    # For each digit (from 0 to 9)
    for i in np.nditer(np.unique(Y)):
        # Create a ndarray with the features of "num_examples" examples of digit "i"
        features = X[Y == i][:num_examples]
        # For each of the "num_examples" examples
        for j in range(num_examples):
            # Create subplot (from 1 to "num_digits"*"num_examples" of each digit)
            plt.subplot(len(np.unique(Y)), num_examples, i * num_examples + j + 1)
            plt.subplots_adjust(wspace=0, hspace=0)
            # Hide tickmarks and scale
            ax = plt.gca()
            # ax.set_axis_off() # Also hide axes (frame)
            ax.axes.get_xaxis().set_visible(False)
            ax.axes.get_yaxis().set_visible(False)
            # Plot the corresponding digit (reshaped to square matrix/image)
            dim = int(np.sqrt(X.shape[1]))
            digit = features[j].reshape((dim,dim))
            plt.imshow(digit)

P1(10)

However, I get this error: "Iterator operand or requested dtype holds references, but the REFS_OK flag was not enabled"

Can anyone help me with this?

Upvotes: 0

Views: 3102

Answers (2)

Sindhu Tirth

Reputation: 1

The error comes from using NumPy's nditer on a non-numeric array. The Y array that fetch_openml returns for MNIST is most likely an array of strings with object dtype, and np.nditer refuses to iterate over an object array such as np.unique(Y) unless the REFS_OK flag is enabled.

Removing nditer, as @stevemo suggested, should fix it.
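To illustrate the point about dtypes, here is a minimal sketch that reproduces the failure on a small stand-in array and shows one way around it. Y_demo is a hypothetical substitute for the real labels, assuming they arrive as digit strings with object dtype; it is not the actual MNIST data.

```python
import numpy as np

# Hypothetical stand-in for the MNIST labels: digit strings with object dtype.
Y_demo = np.array(['3', '0', '1', '0', '3'], dtype=object)

# np.nditer rejects object arrays unless the 'refs_ok' flag is passed.
try:
    for i in np.nditer(np.unique(Y_demo)):
        pass
    nditer_failed = False
except TypeError:
    nditer_failed = True

# Casting the labels to integers gives a plain numeric array that can be
# looped over directly and used in arithmetic such as subplot indices.
Y_int = Y_demo.astype(int)
digits = np.unique(Y_int)
for d in digits:
    count = int(np.sum(Y_int == d))  # number of samples for this digit
```

With integer labels, an expression like i * num_examples + j + 1 from the question also works, which it would not with string labels.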

If you are looking for an alternative way to display MNIST images in a grid, the following function can be used directly:

def display_sample(data, labels, index_range, grid_width=3):
    """
    Displays the sample images with labels as a grid.

    Parameters
        data: Image data array (each entry must already be a 2D image).
        labels: Image label array (label indices must be same as the data array).
        index_range: Index range as list of two elements. First element is the start index, second element is the stop index (exclusive).
        grid_width: Width of display grid i.e. number of columns in which images are to be displayed.

    Usage example:
        start_index = 0
        stop_index = 10
        no_of_columns = 5
        display_sample(X_train, y_train,
                       [start_index, stop_index],
                       no_of_columns)
        >> displays images at indices 0 through 9 from X_train with respective labels from y_train in a grid of 5 columns and 2 rows
    """
    index_range = np.arange(index_range[0], index_range[1])
    # Only complete rows are drawn; leftover images are skipped
    rows = int(len(index_range) / grid_width)
    columns = grid_width
    fig, ax = plt.subplots(rows, columns, figsize=(columns*5, rows*5))

    index_counter = index_range[0]

    for i in range(rows):
        for j in range(columns):
            # plt.subplots returns a 1D array of axes when there is a single row
            if rows > 1:
                ax_id = ax[i, j]
            else:
                ax_id = ax[j]
            ax_id.imshow(data[index_counter])
            # Title each image with its label
            ax_id.set_title(str(labels[index_counter]), size=20)
            ax_id.axis("off")
            index_counter += 1
    plt.show()

Upvotes: 0

stevemo

Reputation: 1097

The error most likely comes from np.nditer, which you don't need here. I'd also recommend using plt.subplots and Axes objects instead of MATLAB-style plt calls:

digits = np.unique(Y)
M = 10
dim = int(np.sqrt(X.shape[1]))

fig, axs = plt.subplots(len(digits), M, figsize=(20,20))

for i,d in enumerate(digits):
    for j in range(M):
        axs[i,j].imshow(X[Y==d][j].reshape((dim,dim)))
        axs[i,j].axis('off')
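For reference, here is a self-contained sketch of the same subplots approach wrapped in a function. The name plot_digit_grid and the random X_fake / Y_fake arrays are hypothetical and only stand in for the real data. It assumes the labels may be digit strings and that X may be a pandas DataFrame (newer scikit-learn versions can return one), so both are converted to NumPy arrays first.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def plot_digit_grid(X, Y, n_per_digit=10):
    """Draw one row per class with n_per_digit samples each (hypothetical helper)."""
    X = np.asarray(X, dtype=float)   # also accepts a DataFrame
    Y = np.asarray(Y).astype(int)    # digit strings -> integers
    digits = np.unique(Y)
    dim = int(np.sqrt(X.shape[1]))   # 784 features -> 28x28 image

    fig, axs = plt.subplots(len(digits), n_per_digit,
                            figsize=(n_per_digit, len(digits)),
                            squeeze=False)
    for i, d in enumerate(digits):
        samples = X[Y == d][:n_per_digit]
        for j in range(n_per_digit):
            axs[i, j].axis("off")
            # Leave the cell empty if a class has fewer samples than requested
            if j < len(samples):
                axs[i, j].imshow(samples[j].reshape(dim, dim), cmap="Greys")
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig, axs


# Random stand-in data: 12 fake "images" per digit, labels stored as strings.
rng = np.random.default_rng(0)
Y_fake = np.repeat(np.arange(10), 12).astype(str).astype(object)
X_fake = rng.random((Y_fake.size, 784))
fig, axs = plot_digit_grid(X_fake, Y_fake)
```

With the real data you would call plot_digit_grid(X, Y) and then plt.show() or fig.savefig(...) under an interactive or file-writing backend.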

Upvotes: 2
