Nicolas Gervais

How do I turn a TensorFlow Dataset into a NumPy Array?

I'm interested in a TensorFlow Dataset, but I want to manipulate it with NumPy. Is it possible to turn this PrefetchDataset into a NumPy array?

import tensorflow_datasets as tfds
import numpy as np

dataset = tfds.load('mnist')

Answers (1)

Nicolas Gervais

Since you didn't specify split, tfds.load returns a dictionary that maps each split name ('train' and 'test') to its own dataset. And since as_supervised defaults to False, each example is itself a dictionary with separate 'image' and 'label' entries. This is what it will look like:

{'test': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, 
    types: {image: tf.uint8, label: tf.int64}>,
 'train': <PrefetchDataset shapes: {image: (28, 28, 1), label: ()}, 
    types: {image: tf.uint8, label: tf.int64}>}

So here's how you can turn each split into NumPy arrays:

import tensorflow_datasets as tfds
import numpy as np

dataset = tfds.load('mnist')

train, test = dataset['train'], dataset['test']

# tfds.as_numpy yields one dict of NumPy values per example;
# materialize each split once, then pull out the images and labels
train_examples = list(tfds.as_numpy(train))
test_examples = list(tfds.as_numpy(test))

X_train = np.array([ex['image'] for ex in train_examples])
y_train = np.array([ex['label'] for ex in train_examples])

X_test = np.array([ex['image'] for ex in test_examples])
y_test = np.array([ex['label'] for ex in test_examples])
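If you'd rather not write one comprehension per key, the per-example dictionaries can be stacked key by key in a single pass. The sketch below uses a hypothetical helper, stack_examples (not part of tfds), and a few synthetic MNIST-shaped examples standing in for real tfds output, so it runs with NumPy alone; with real data you would pass it tfds.as_numpy(train) instead.

```python
import numpy as np

def stack_examples(examples):
    """Hypothetical helper: turn an iterable of {'image': ..., 'label': ...}
    dicts (one per example) into a dict of stacked NumPy arrays."""
    columns = {}
    for example in examples:
        for key, value in example.items():
            columns.setdefault(key, []).append(value)
    # np.stack adds a new leading axis, so N images of shape (28, 28, 1)
    # become one array of shape (N, 28, 28, 1), and N scalar labels become (N,)
    return {key: np.stack(values) for key, values in columns.items()}

# Synthetic stand-in for tfds.as_numpy(train): 5 fake MNIST-shaped examples
fake_examples = [
    {'image': np.zeros((28, 28, 1), dtype=np.uint8), 'label': np.int64(i)}
    for i in range(5)
]

arrays = stack_examples(fake_examples)
X_train, y_train = arrays['image'], arrays['label']
print(X_train.shape, y_train.shape)  # (5, 28, 28, 1) (5,)
```

np.stack is used rather than np.vstack because it insists that every example has the same shape and puts the example index on a new first axis, which is exactly the (N, 28, 28, 1) layout you want.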

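You might want to set as_supervised=True, which makes each example an (image, label) tuple instead of a dictionary. If you also pass split as a list, tfds.load returns a list of datasets, one per requested split: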

[<PrefetchDataset shapes: ((28, 28, 1), ()), types: (tf.uint8, tf.int64)>]

In this case, you select the image and label by position ([0] and [1]) instead of by key, and the test split itself is dataset[0]. So here's how you can turn it into NumPy arrays:

import tensorflow_datasets as tfds
import numpy as np

dataset = tfds.load('mnist', split=['test'], as_supervised=True)

# dataset is a list with one entry per requested split; dataset[0] is 'test'.
# Each element is an (image, label) tuple of NumPy values.
examples = list(tfds.as_numpy(dataset[0]))

X_test = np.array([ex[0] for ex in examples])
y_test = np.array([ex[1] for ex in examples])

Proof:

X_test.shape
(10000, 28, 28, 1)
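Because each supervised example is an (image, label) tuple, the two columns can also be separated in one step with zip(*...). A minimal sketch, using synthetic tuples as a stand-in for list(tfds.as_numpy(dataset[0])) so it runs with NumPy only:

```python
import numpy as np

# Synthetic stand-in for list(tfds.as_numpy(dataset[0])) with as_supervised=True:
# three fake (image, label) pairs shaped like MNIST
fake_examples = [
    (np.full((28, 28, 1), i, dtype=np.uint8), np.int64(i))
    for i in range(3)
]

# zip(*...) regroups [(img0, lbl0), (img1, lbl1), ...] into
# (img0, img1, ...) and (lbl0, lbl1, ...); np.stack then joins each group
images, labels = zip(*fake_examples)
X_test = np.stack(images)
y_test = np.stack(labels)

print(X_test.shape, y_test.shape)  # (3, 28, 28, 1) (3,)
```

Note that the unpacking fails with a ValueError if the list of examples is empty, so this assumes the split contains at least one example.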
