Cristi Vlad
Cristi Vlad

Reputation: 81

Have an Issue Implementing CIFAR10 in Tensorflow

import os
import numpy as np
import pickle

class CifarLoader(object):
    def __init__(self, source_files):
        self._source = source_files
        self._i = 0
        self.images = None
        self.labels = None

    def load(self):
        data = [unpickle(f) for f in self._source] #again a list comprehension
        images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
        n = len(images)
        self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel
        self.labels = one_hot(np.hstack([d["labels"] for d in data]), 10)
        return self

    def next_batch(self, batch_size):
        x, y = self.images[self._i:self._i+batch_size], self.labels[self._i:self._i+batch_size]
        self._i = (sel._i + batch_size) % len(self.images)
        return x, y

DATA_PATH = "cifar10"

def unpickle(file):
    with open(os.path.join(DATA_PATH, file), 'rb') as fo:
        dict = pickle.load(fo, encoding='bytes')
    return dict

def one_hot(vec, vals=10):
    n = len(vec)
    out = np.zeros((n, vals))
    out[range(n), vec] = 1
    return out

class CifarDataManager(object):
    def __init__(self):
        self.train = CifarLoader(["data_batch_{}".format(i) for i in range(1, 6)]).load()
        self.test = CifarLoader(["test_batch"]).load()

def display_cifar(images, size):
    n = len(images)
    plt.figure()
    plt.gca().set_axis_off()
    im = np.vstack([np.hstack([images[np.random.choice(n)] for i in range(size)]) for i in range(size)])
    plt.imshow(im)
    plt.show()

d = CifarDataManager()

print ("Number of train images: {}".format(len(d.train.images)))
print ("Number of train labels: {}".format(len(d.train.labels)))
print ("Number of test images: {}".format(len(d.test.images)))
print ("Number of test images: {}".format(len(d.test.labels)))
images = d.train.images
display_cifar(images, 10)

And this is the error I'm getting.

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-182-b3f5a6bd2e1d> in <module>()
      7     plt.show()
      8 
----> 9 d = CifarDataManager()
     10 
     11 print ("Number of train images: {}".format(len(d.train.images)))

<ipython-input-181-e85d41d02848> in __init__(self)
      1 class CifarDataManager(object):
      2     def __init__(self):
----> 3         self.train = CifarLoader(["data_batch_{}".format(i) for i in range(1, 6)]).load()
      4         self.test = CifarLoader(["test_batch"]).load()

<ipython-input-179-d96c4afcda51> in load(self)
     12     def load(self):
     13         data = [unpickle(f) for f in self._source] #again a list comprehension
---> 14         images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
     15         n = len(images)
     16         self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel

<ipython-input-179-d96c4afcda51> in <listcomp>(.0)
     12     def load(self):
     13         data = [unpickle(f) for f in self._source] #again a list comprehension
---> 14         images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
     15         n = len(images)
     16         self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel

KeyError: 'data'

Any help is appreciated! I suspect the issue has to do with pickle and Python3 and the way it loads the data.

Upvotes: 0

Views: 92

Answers (1)

Y. Luo
Y. Luo

Reputation: 5722

Thank you for checking out your files and posting the result. It is clear now your key is bytes string (bytes). Since you didn't specify, I can only guess you are using python3 which can't convert bytes object to string implicitly (see the note in this section). Try the following under python 2 and python 3, and you may have a better idea:

d = {b'a': 1, b'b': 2}
print(d.keys())
try:
    print('Key "a" gives: {}'.format(d["a"]))
except Exception as err:
    print('Get "{}"!'.format(err.__class__.__name__))
    print('Key b"a" gives: {}'.format(d[b"a"]))

Upvotes: 1

Related Questions