I am learning how to build an MNIST model from scratch in TensorFlow 2.0 and Keras, following a Udemy course.
So, I loaded the MNIST dataset as follows:
import tensorflow_datasets as tfds
mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
Everything was fine; I even got 97% accuracy when testing my model, and I was happy.
The problem started when I tried to do something the course didn't cover: I tried to display some examples from mnist_dataset using matplotlib's plt.imshow(), and I failed completely. After some research I found a solution: I needed to load the dataset like this:
mnist_dataset2 = tfds.load(name = 'mnist')
mnistt = mnist_dataset2['train']
where mnistt is the dataset I can manipulate and plot with matplotlib.
So my question is: where can I find information about the kinds of objects tfds.load() can return, and how to manipulate them correctly? (Ideally explained in a way that a TensorFlow beginner like me can follow.)
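Try this:
image, label = next(iter(mnist_train))
then plot image, for example with plt.imshow(image.numpy().squeeze(), cmap='gray'). Note that next(iter(...)) returns a single (image, label) example, not the whole training set, and the image has shape (28, 28, 1), so the channel axis is squeezed away before plotting.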
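To look at more than one digit, and to handle both ways of loading the data from the question (as_supervised=True gives (image, label) tuples, the default gives {'image': ..., 'label': ...} dicts), here is a minimal sketch. It assumes the elements have already been converted to NumPy, e.g. with tfds.as_numpy; the helper names to_image_label and plot_examples are hypothetical.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np


def to_image_label(example):
    """Return (image, label) for either tfds element layout."""
    if isinstance(example, dict):
        # as_supervised=False (the default): each element is a feature dict
        return example["image"], example["label"]
    # as_supervised=True: each element is already an (image, label) tuple
    image, label = example
    return image, label


def plot_examples(examples, n=9, cols=3):
    """Plot the first n examples from an iterable of NumPy elements."""
    rows = (n + cols - 1) // cols  # enough rows to hold n images
    fig, axes = plt.subplots(rows, cols, figsize=(2 * cols, 2 * rows), squeeze=False)
    axes = axes.ravel()
    for ax in axes:
        ax.axis("off")  # hide ticks and frames in every slot
    for ax, example in zip(axes, itertools.islice(examples, n)):
        image, label = to_image_label(example)
        # MNIST images are (28, 28, 1); squeeze to (28, 28) for a grayscale plot
        ax.imshow(np.squeeze(image), cmap="gray")
        ax.set_title(f"label: {int(label)}")
    return fig
```

For example, fig = plot_examples(tfds.as_numpy(mnist_train)) followed by plt.show() should display nine training digits with their labels, and the same call should work on the dict-style mnistt dataset.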
The main invocation of the tfds.load method contains everything you need:
mnist_dataset, mnist_info = tfds.load(name = 'mnist', with_info=True, as_supervised=True)
name="mnist" -> you're specifying the builder you want to use (mnist)
with_info=True -> you're asking tfds.load to return the info object, which contains everything you need to know about the returned dataset
as_supervised=True -> you're asking tfds.load to return only the elements of the dataset needed for a supervised learning task (the (image, label) pair)

Your first attempt at using mnist_dataset to get the data (to use with matplotlib) failed because, as you can see from

print(mnist_info) #run me!

the dataset contains 2 different splits: train and test.
tfds.core.DatasetInfo(
name='mnist',
version=1.0.0,
description='The MNIST database of handwritten digits.',
urls=['https://storage.googleapis.com/cvdf-datasets/mnist/'],
features=FeaturesDict({
'image': Image(shape=(28, 28, 1), dtype=tf.uint8),
'label': ClassLabel(shape=(), dtype=tf.int64, num_classes=10),
}),
total_num_examples=70000,
splits={
'test': 10000,
'train': 60000,
},
supervised_keys=('image', 'label'),
citation="""@article{lecun2010mnist,
title={MNIST handwritten digit database},
author={LeCun, Yann and Cortes, Corinna and Burges, CJ},
journal={ATT Labs [Online]. Available: http://yann. lecun. com/exdb/mnist},
volume={2},
year={2010}
}""",
redistribution_info=,
)
Thus, the first object returned by tfds.load (mnist_dataset) is a dictionary:
{
"train": <train dataset>,
"test": <test dataset>
}
In fact, in the next line of the example, you extract the "train" and "test" datasets in this way:
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']
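Since mnist_dataset is an ordinary dictionary from split name to dataset, you can loop over it like any Python dict. The sketch below counts the examples in each split; split_sizes is a hypothetical helper name, and it works on anything iterable (tf.data datasets or tfds.as_numpy generators). If you only need one split, tfds.load also accepts a split argument (e.g. split='train'), in which case it returns that single dataset instead of a dictionary.

```python
def split_sizes(datasets):
    """Count the examples in each split of a {split_name: dataset} dict."""
    # Iterating works for tf.data.Dataset objects and for tfds.as_numpy(...) generators;
    # it walks every example, so it can take a moment on large splits.
    return {name: sum(1 for _ in ds) for name, ds in datasets.items()}
```

Calling split_sizes(mnist_dataset) should report 60000 examples for "train" and 10000 for "test", matching the splits entry in the printed info.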
From the mnist_info object you can get all the information you need to manipulate your dataset: the number of splits, the data types (e.g. "image" is a 28x28x1 image with dtype tf.uint8), etc.
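Those values can also be read programmatically instead of from the printed text. The sketch below assumes the mnist_info object from the tfds.load(..., with_info=True) call; summarize_info is a hypothetical helper, and it reads the shape of the Image feature, num_classes of the ClassLabel feature, and num_examples of each entry in info.splits.

```python
def summarize_info(info):
    """Collect a few dataset facts from a tfds DatasetInfo-like object."""
    image_feature = info.features["image"]  # e.g. Image(shape=(28, 28, 1), dtype=tf.uint8)
    label_feature = info.features["label"]  # e.g. ClassLabel(num_classes=10)
    return {
        "image_shape": tuple(image_feature.shape),
        "num_classes": label_feature.num_classes,
        # info.splits maps split names ('train', 'test') to objects with num_examples
        "split_sizes": {name: split.num_examples for name, split in info.splits.items()},
    }
```

Based on the info printed above, summarize_info(mnist_info) should give an image shape of (28, 28, 1), 10 classes, and 60000/10000 examples for train/test; values like num_classes are handy when sizing the model's output layer.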