I'm importing the MNIST dataset from Keras using `(x_train, y_train), (x_test, y_test) = mnist.load_data()`, and what I want to do is sort the samples by their corresponding digit. I imagine there's some trivial way to do this, but I can't find any label attribute on the data. Is there a simple way to do this?
`y_train` and `y_test` are the vectors containing the labels for the images in `x_train` and `x_test`, respectively. They tell you which digit is shown in each image. So just get the indices that sort these vectors with `np.argsort`, and then use those indices to reorder the corresponding image array:
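```python
import numpy as np

# Permutation that orders the labels 0, 0, ..., 1, 1, ..., 9;
# kind="stable" keeps same-digit images in their original order.
idx = np.argsort(y_train, kind="stable")
x_train_sorted = x_train[idx]  # apply the same permutation to the images
y_train_sorted = y_train[idx]
```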
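A minimal, self-contained sketch of the same idea, assuming small random arrays as a hypothetical stand-in for `mnist.load_data()` so it runs without a download. It checks that one permutation keeps every image paired with its label, then (as an extra step not in the answer above) splits the sorted images into one array per digit using the start index of each label run from `np.unique(..., return_index=True)`:

```python
import numpy as np

# Hypothetical stand-in for mnist.load_data(): 12 random 28x28 uint8 images
# with labels 0-9 (real MNIST has 60,000 training images of this shape).
rng = np.random.default_rng(0)
y_train = rng.integers(0, 10, size=12).astype(np.uint8)
x_train = rng.integers(0, 256, size=(12, 28, 28)).astype(np.uint8)

# One permutation sorts the labels; applying it to both arrays keeps
# every image paired with its own label.
idx = np.argsort(y_train, kind="stable")
x_train_sorted = x_train[idx]
y_train_sorted = y_train[idx]

# In the sorted labels each digit occupies one contiguous run;
# return_index=True gives the position where each run starts.
digits, starts = np.unique(y_train_sorted, return_index=True)
blocks = np.split(x_train_sorted, starts[1:])  # one image array per digit present

for d, block in zip(digits, blocks):
    print(d, block.shape)
```

Only digits that actually occur get a block, so with the full MNIST training set you would get ten blocks.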
If you only want the images of a particular digit, you can select them directly with a boolean mask on the labels:

```python
x_train_zeros = x_train[y_train == 0]
x_train_ones = x_train[y_train == 1]
# and so on...
```
Notice that in this case you don't need to pre-sort the data.
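To do the "and so on" for all ten digits at once, a dictionary comprehension over the same masks works. This is a sketch, again assuming random arrays as a hypothetical stand-in for the MNIST data; `by_digit` is an illustrative name, and `np.bincount` is used only to cross-check the group sizes:

```python
import numpy as np

# Hypothetical stand-in for the MNIST training data (same per-image shape/dtype).
rng = np.random.default_rng(1)
y_train = rng.integers(0, 10, size=50).astype(np.uint8)
x_train = rng.integers(0, 256, size=(50, 28, 28)).astype(np.uint8)

# One boolean mask per digit; no sorting needed.
by_digit = {d: x_train[y_train == d] for d in range(10)}

# np.bincount counts how many images carry each label 0-9.
counts = np.bincount(y_train, minlength=10)
for d in range(10):
    print(d, by_digit[d].shape[0], counts[d])
```

Each mask scans the whole label vector, so this makes ten passes over the data; for 60,000 labels that is still cheap.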