Reputation:
I am trying to apply a saliency map to each image and build a new dataset from the result.
trainloader = utilsxai.load_data_cifar10(batch_size=1, test=False)
testloader = utilsxai.load_data_cifar10(batch_size=128, test=True)
load_data_cifar10 is built on torchvision's CIFAR10 dataset.
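(For context, load_data_cifar10 is just a small helper around torchvision; roughly something like the sketch below, although the exact root path and arguments differ:)

import torchvision
import torchvision.transforms as T
from torch.utils.data import DataLoader

def load_data_cifar10(batch_size, test=False):
    # Resize + ToTensor, matching the transform printed by the dataset further down
    transform = T.Compose([T.Resize(32), T.ToTensor()])
    dataset = torchvision.datasets.CIFAR10(root="./data",      # placeholder root, real path differs
                                           train=not test,
                                           download=True,
                                           transform=transform)
    return DataLoader(dataset, batch_size=batch_size, shuffle=not test)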
data = trainloader.dataset.data
trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape)
sal_maps_hf has shape (50000, 32, 32, 3), and trainloader.dataset.data has the same shape.
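For reference, a quick check of the shapes before the multiplication (just a sketch with the variables above):

data = trainloader.dataset.data
print(data.shape)         # (50000, 32, 32, 3)
print(sal_maps_hf.shape)  # (50000, 32, 32, 3) -- sal_maps_hf is assumed to be a NumPy array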
But when I run this:
for idx, img in enumerate(trainloader):
I get the following error:
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode)
   2644             typekey = (1, 1) + shape[2:], arr["typestr"]
-> 2645             mode, rawmode = _fromarray_typemap[typekey]
   2646         except KeyError:

KeyError: ((1, 1, 3), '<f4')

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-142-9410d0967245> in <module>
----> 1 show_images(trainloader)

<ipython-input-117-a32f5bd33032> in show_images(trainloader)
      1 def show_images(trainloader):
----> 2     for idx,(img,target) in enumerate(trainloader):
      3         img = img.squeeze()
      4         #pritn(img)
      5         img = torch.tensor(img)

~/venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torchvision/datasets/cifar.py in __getitem__(self, index)
    120         # doing this so that it is consistent with all other datasets
    121         # to return a PIL Image
--> 122         img = Image.fromarray(img)
    123
    124         if self.transform is not None:

~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode)
   2645             mode, rawmode = _fromarray_typemap[typekey]
   2646         except KeyError:
-> 2647             raise TypeError("Cannot handle this data type")
   2648         else:
   2649             rawmode = mode

TypeError: Cannot handle this data type
trainloader.dataset.__getitem__
<bound method CIFAR10.__getitem__ of Dataset CIFAR10
    Number of datapoints: 50000
    Root location: /mnt/3CE35B99003D727B/input/pytorch/data
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=32, interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )>
Upvotes: 1
Views: 1338
Reputation: 114786
Your sal_maps_hf is not np.uint8.
Based on the partial information in the question and in the comments, I guess that your mask has a float dtype (the '<f4' in the traceback is float32). Multiplying data * sal_maps_hf therefore casts your data to a dtype other than np.uint8, which later makes PIL.Image throw an exception.
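You can reproduce the failure in isolation (a minimal sketch, assuming a float32 mask, which is what '<f4' corresponds to):

import numpy as np
from PIL import Image

images = np.zeros((10, 32, 32, 3), dtype=np.uint8)        # CIFAR-style uint8 images
mask = np.random.rand(10, 32, 32, 3).astype(np.float32)   # hypothetical float saliency mask

masked = images * mask
print(masked.dtype)         # float32 -- the uint8 data was promoted
Image.fromarray(masked[0])  # TypeError: Cannot handle this data type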
Try:
trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape).astype(np.uint8)
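If the saliency maps are normalized to [0, 1] (an assumption; adjust if yours are not), the product stays inside [0, 255] and the plain cast above is enough. To be safe against masks outside that range, clip before casting:

import numpy as np

data = trainloader.dataset.data                  # original uint8 images, (50000, 32, 32, 3)
masked = data.astype(np.float32) * sal_maps_hf   # float result of the masking
masked = np.clip(masked, 0, 255)                 # guard against values outside the uint8 range
trainloader.dataset.data = masked.reshape(data.shape).astype(np.uint8)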
Upvotes: 1