decipher

Reputation: 498

Reading Cifar10 dataset in batches

I am trying to read the CIFAR-10 dataset, which is provided in batches at https://www.cs.toronto.edu/~kriz/cifar.html. I load a batch with pickle and try to read its 'data' entry so I can put it in a data frame, but I get this error:

KeyError                                  Traceback (most recent call last)
<ipython-input-24-8758b7a31925> in <module>()
----> 1 unpickle('datasets/cifar-10-batches-py/test_batch')

<ipython-input-23-04002b89d842> in unpickle(file)
      3     fo = open(file, 'rb')
      4     dict = pickle.load(fo, encoding ='bytes')
----> 5     X = dict['data']
      6     fo.close()
      7     return dict

KeyError: 'data'.

I am using IPython, and here is my code:

def unpickle(file):

 fo = open(file, 'rb')
 dict = pickle.load(fo, encoding ='bytes')
 X = dict['data']
 fo.close()
 return dict

unpickle('datasets/cifar-10-batches-py/test_batch')

Upvotes: 5

Views: 18902

Answers (6)

vxnuaj

Reputation: 1

Not sure if this was ever solved for the OP, so this reply is for anyone else looking for an answer, as I was recently.

All you need to do is prefix the key with b when indexing the unpickled dict.

import pickle

def unpickle(file):
    with open(file, 'rb') as f:
        data = pickle.load(f, encoding='bytes')
    return data

batch1 = unpickle('data/data_batch_1')

# With encoding='bytes' the dict keys are bytes objects, so index with a b'' literal
filenames = batch1[b'filenames']

The b prefix tells Python that the key is a bytes literal, which matches the bytes keys produced by pickle.load(..., encoding='bytes').
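If you would rather not write b'' in front of every key, one option is to decode the keys once after loading. The snippet below is only a sketch: decode_keys is a hypothetical helper name (not part of pickle or of the CIFAR-10 files), and the small batch dict is made-up stand-in data so the example runs without the dataset.

```python
def decode_keys(raw):
    """Return a copy of `raw` whose bytes keys are decoded to str.

    Hypothetical helper for illustration; values are left untouched.
    """
    return {
        key.decode("ascii") if isinstance(key, bytes) else key: value
        for key, value in raw.items()
    }


# Stand-in for a batch loaded with encoding='bytes' (values are made up).
batch = {b"labels": [6, 9], b"filenames": [b"frog.png", b"truck.png"]}

decoded = decode_keys(batch)
print(sorted(decoded))  # ['filenames', 'labels']
```

Note that only the keys are decoded here; values such as the file names stay as bytes.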

Upvotes: 0

Dheemanth Bhat

Reputation: 4452

This answer is based on Sohaib Anwaar's answer, but changed to return the dataset as a TensorFlow Dataset (tf.data.Dataset) instead of a NumPy array.

Why TensorFlow Dataset?

tf.data.Dataset offers easy-to-use, high-performance input pipelines and is the "correct" way to access any dataset in TensorFlow 2.x.

Python version >= 3.10

For Python >= 3.10, the simplest way to obtain a TensorFlow Dataset is to use tensorflow_datasets.

import tensorflow_datasets as tfds


(ds_train, ds_cval, ds_test), ds_info = tfds.load(
    "cifar10",
    split=["train[:75%]", "train[75%:]", "test"], 
    as_supervised=True,
    with_info=True
)

print("Train dataset size:", len(ds_train))
print("Cross-validation dataset size:", len(ds_cval))
print("Test dataset size:", len(ds_test))

Output

Train dataset size: 37500
Cross-validation dataset size: 12500
Test dataset size: 10000

Python version <= 3.9.x

After downloading the CIFAR-10 dataset, extract the contents of the tar.gz archive to a folder named data.

Changes from accepted answer

  1. load_CIFAR10 returns TensorFlow Dataset instead of NumPy array.
  2. load_CIFAR10 splits dataset into Train, Cross-Validation and Test sets.

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

import os
import math
import platform
import pickle


def load_CIFAR_batch(file_path):
    """
    Load single batch of CIFAR-10 images from
    the binary file and return as a NumPy array.
    """
    with open(file_path, "rb") as f:
        data_dict = pickle.load(f, encoding="latin1")

        # Extract NumPy from dictionary.
        X = data_dict["data"]
        y = data_dict["labels"]

        # Reshape and transpose flat array as 32 X 32 RGB image.
        X = X.reshape(CIFAR_BATCH_SIZE, *input_shape, order="F")
        X = X.transpose((0, 2, 1, 3))

        # Convert `labels` to vector.
        y = np.expand_dims(y, axis=1)

        return X, y


def load_CIFAR10(cv_size=0.25):
    """
    Load all batches of CIFAR-10 images from the
    binary file and return as TensorFlow DataSet.
    """
    X_btchs = []
    y_btchs = []
    for batch in range(1, 6):
        file_path = os.path.join(ROOT, "data_batch_%d" % (batch,))
        X, y = load_CIFAR_batch(file_path)
        X_btchs.append(X)
        y_btchs.append(y)

    # Combine all batches.
    all_Xbs = np.concatenate(X_btchs)
    all_ybs = np.concatenate(y_btchs)

    # Convert Train dataset from NumPy array to TensorFlow Dataset.
    ds_all = tf.data.Dataset.from_tensor_slices((all_Xbs, all_ybs))

    al_size = len(ds_all)
    tr_size = math.ceil((1 - cv_size) * al_size)
    # Split dataset into Train and Cross-validation sets.
    ds_tr = ds_all.take(tr_size)
    ds_cv = ds_all.skip(tr_size)
    print(f"Train dataset size: {tr_size}.")
    print(f"Cross-validation dataset size: {al_size - tr_size}.")

    # Convert Test dataset from NumPy array to TensorFlow Dataset.
    X_ts, y_ts = load_CIFAR_batch(os.path.join(ROOT, "test_batch"))
    ds_ts = tf.data.Dataset.from_tensor_slices((X_ts, y_ts))
    print(f"Test dataset size {len(ds_ts)}.")

    return ds_tr, ds_cv, ds_ts


ROOT = "../data/cifar-10-batches-py/"

CIFAR_BATCH_SIZE = 10000
img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)

ds_tr, ds_cv, ds_ts = load_CIFAR10()

Output

Train dataset size: 37500.
Cross-validation dataset size: 12500.
Test dataset size 10000.

Confirm Dataset

xi, yi = ds_tr.as_numpy_iterator().next()

plt.imshow(xi)
plt.title(f"Class label: {yi[0]}")
plt.show()

Frog image

P.S.
The image above is a very low-resolution (32x32) picture of a frog.
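For anyone puzzled by the reshape and transpose step: each row of 'data' stores 3072 values, with all 1024 red values first, then green, then blue, and each colour plane laid out row by row. The pure-Python sketch below spells out that index arithmetic. row_to_image is a hypothetical name used only for illustration, and the tiny 2x2 input is made up so the result is easy to check by hand. With NumPy the same rearrangement is commonly written as X.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1).

```python
def row_to_image(flat_row, height=32, width=32, channels=3):
    """Convert one channel-major flat row into nested [row][col][channel] lists.

    Hypothetical helper for illustration; it assumes the layout
    index = channel * (height * width) + row * width + col.
    """
    plane = height * width
    if len(flat_row) != channels * plane:
        raise ValueError("expected %d values, got %d" % (channels * plane, len(flat_row)))
    return [
        [
            [flat_row[c * plane + r * width + col] for c in range(channels)]
            for col in range(width)
        ]
        for r in range(height)
    ]


# Tiny made-up 2x2 "image": values 0-3 are red, 4-7 green, 8-11 blue.
tiny = row_to_image(list(range(12)), height=2, width=2, channels=3)
print(tiny[0][0])  # [0, 4, 8]
print(tiny[1][1])  # [3, 7, 11]
```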

Upvotes: 0

ndrplz

Reputation: 1654

I went through similar issues in the past.

I'd like to mention for future readers that you can find here a python wrapper for automatically downloading, extracting and parsing the cifar10 dataset.

Upvotes: 0

Sohaib Anwaar

Reputation: 1547

You can read the CIFAR-10 dataset with the code below. Just make sure you point it at the right directory, the one where the batch files are stored.

import tensorflow as tf
import pandas as pd
import numpy as np
import math
import timeit
import matplotlib.pyplot as plt
from six.moves import cPickle as pickle
import os
import platform
from subprocess import check_output

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

%matplotlib inline


img_rows, img_cols = 32, 32
input_shape = (img_rows, img_cols, 3)


def load_pickle(f):
    version = platform.python_version_tuple()
    if version[0] == '2':
        return  pickle.load(f)
    elif version[0] == '3':
        return  pickle.load(f, encoding='latin1')
    raise ValueError("invalid python version: {}".format(version))

def load_CIFAR_batch(filename):
    """ load single batch of cifar """
    with open(filename, 'rb') as f:
        datadict = load_pickle(f)
        X = datadict['data']
        Y = datadict['labels']
        X = X.reshape(10000,3072)
        Y = np.array(Y)
        return X, Y

def load_CIFAR10(ROOT):
    """ load all of cifar """
    xs = []
    ys = []
    for b in range(1,6):
        f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
        X, Y = load_CIFAR_batch(f)
        xs.append(X)
        ys.append(Y)
    Xtr = np.concatenate(xs)
    Ytr = np.concatenate(ys)
    del X, Y
    Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
    return Xtr, Ytr, Xte, Yte


def get_CIFAR10_data(num_training=49000, num_validation=1000, num_test=10000):
    # Load the raw CIFAR-10 data
    cifar10_dir = '../input/cifar-10-batches-py/'
    X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)

    # Subsample the data
    mask = range(num_training, num_training + num_validation)
    X_val = X_train[mask]
    y_val = y_train[mask]
    mask = range(num_training)
    X_train = X_train[mask]
    y_train = y_train[mask]
    mask = range(num_test)
    X_test = X_test[mask]
    y_test = y_test[mask]

    # Scale pixel values to [0, 1]; the validation set is scaled the same way
    x_train = X_train.astype('float32') / 255
    x_val = X_val.astype('float32') / 255
    x_test = X_test.astype('float32') / 255

    return x_train, y_train, x_val, y_val, x_test, y_test


# Invoke the above function to get our data.
x_train, y_train, x_val, y_val, x_test, y_test = get_CIFAR10_data()


print('Train data shape: ', x_train.shape)
print('Train labels shape: ', y_train.shape)
print('Validation data shape: ', x_val.shape)
print('Validation labels shape: ', y_val.shape)
print('Test data shape: ', x_test.shape)
print('Test labels shape: ', y_test.shape)

Upvotes: 11

Jingnan Jia

Reputation: 1309

I had the same problem and solved it. The cause is the encoding: with encoding='bytes' the dictionary keys are loaded as bytes objects (b'data'), so the str key 'data' is not found. Change the code from

dict = pickle.load(fo, encoding ='bytes')

to

dict = pickle.load(fo, encoding ='latin1')
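To see the difference without downloading the dataset, here is a small sketch. PY2_PICKLE is a hand-written protocol-0 pickle, assumed to be equivalent to what Python 2 would write for {'data': 1}. It only stands in for a real CIFAR-10 batch file.

```python
import pickle

# Hand-written protocol-0 pickle, assumed equivalent to what Python 2
# emits for {'data': 1}. It stands in for a real CIFAR-10 batch file.
PY2_PICKLE = b"(dp0\nS'data'\np1\nI1\ns."

as_bytes = pickle.loads(PY2_PICKLE, encoding="bytes")
as_text = pickle.loads(PY2_PICKLE, encoding="latin1")

print(as_bytes)            # {b'data': 1}
print(as_text)             # {'data': 1}
print("data" in as_bytes)  # False, which is why dict['data'] raises KeyError
print("data" in as_text)   # True
```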

Upvotes: 2

Dixith

Reputation: 1

Try this

def unpickle(file):
    import cPickle  # Python 2 only; in Python 3 use the built-in pickle module
    with open(file, 'rb') as fo:
        data = cPickle.load(fo)
    return data

Upvotes: -1
