Reputation: 81
import os
import numpy as np
import pickle
class CifarLoader(object):
def __init__(self, source_files):
self._source = source_files
self._i = 0
self.images = None
self.labels = None
def load(self):
data = [unpickle(f) for f in self._source] #again a list comprehension
images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
n = len(images)
self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel
self.labels = one_hot(np.hstack([d["labels"] for d in data]), 10)
return self
def next_batch(self, batch_size):
x, y = self.images[self._i:self._i+batch_size], self.labels[self._i:self._i+batch_size]
self._i = (sel._i + batch_size) % len(self.images)
return x, y
DATA_PATH = "cifar10"
def unpickle(file):
with open(os.path.join(DATA_PATH, file), 'rb') as fo:
dict = pickle.load(fo, encoding='bytes')
return dict
def one_hot(vec, vals=10):
n = len(vec)
out = np.zeros((n, vals))
out[range(n), vec] = 1
return out
class CifarDataManager(object):
def __init__(self):
self.train = CifarLoader(["data_batch_{}".format(i) for i in range(1, 6)]).load()
self.test = CifarLoader(["test_batch"]).load()
def display_cifar(images, size):
n = len(images)
plt.figure()
plt.gca().set_axis_off()
im = np.vstack([np.hstack([images[np.random.choice(n)] for i in range(size)]) for i in range(size)])
plt.imshow(im)
plt.show()
d = CifarDataManager()
print ("Number of train images: {}".format(len(d.train.images)))
print ("Number of train labels: {}".format(len(d.train.labels)))
print ("Number of test images: {}".format(len(d.test.images)))
print ("Number of test images: {}".format(len(d.test.labels)))
images = d.train.images
display_cifar(images, 10)
And this is the error I'm getting.
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
<ipython-input-182-b3f5a6bd2e1d> in <module>()
7 plt.show()
8
----> 9 d = CifarDataManager()
10
11 print ("Number of train images: {}".format(len(d.train.images)))
<ipython-input-181-e85d41d02848> in __init__(self)
1 class CifarDataManager(object):
2 def __init__(self):
----> 3 self.train = CifarLoader(["data_batch_{}".format(i) for i in range(1, 6)]).load()
4 self.test = CifarLoader(["test_batch"]).load()
<ipython-input-179-d96c4afcda51> in load(self)
12 def load(self):
13 data = [unpickle(f) for f in self._source] #again a list comprehension
---> 14 images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
15 n = len(images)
16 self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel
<ipython-input-179-d96c4afcda51> in <listcomp>(.0)
12 def load(self):
13 data = [unpickle(f) for f in self._source] #again a list comprehension
---> 14 images = np.vstack([d["data"] for d in data]) #so vstack stacks these arrays in sequence vertically or row wise
15 n = len(images)
16 self.images = images.reshape(n, 3, 32, 32).transpose(0, 2, 3, 1).astype(float)/255 #number of possible shades for each channel
KeyError: 'data'
Any help is appreciated! I suspect the issue has to do with pickle and Python3 and the way it loads the data.
Upvotes: 0
Views: 92
Reputation: 5722
Thank you for checking out your files and posting the result. It is clear now your key is bytes string (bytes). Since you didn't specify, I can only guess you are using python3 which can't convert bytes object to string implicitly (see the note in this section). Try the following under python 2 and python 3, and you may have a better idea:
d = {b'a': 1, b'b': 2}
print(d.keys())
try:
print('Key "a" gives: {}'.format(d["a"]))
except Exception as err:
print('Get "{}"!'.format(err.__class__.__name__))
print('Key b"a" gives: {}'.format(d[b"a"]))
Upvotes: 1