user11146622

TypeError: Cannot handle this data type

I'm trying to apply saliency maps to the CIFAR-10 training images and build a new dataset from the result.

trainloader = utilsxai.load_data_cifar10(batch_size=1,test=False)
testloader = utilsxai.load_data_cifar10(batch_size=128, test=True)

Here utilsxai.load_data_cifar10 returns a DataLoader over torchvision's CIFAR10 dataset.

data = trainloader.dataset.data 

trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape)

sal_maps_hf has shape (50000, 32, 32, 3), the same as trainloader.dataset.data.

But as soon as I iterate over the loader, e.g.

for idx,img in enumerate(trainloader):

it fails with the following traceback:
--------------------------------------------------------------------------- 
KeyError                                  Traceback (most recent call last)
~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode)
   2644             typekey = (1, 1) + shape[2:], arr["typestr"]
-> 2645             mode, rawmode = _fromarray_typemap[typekey]
   2646         except KeyError:

KeyError: ((1, 1, 3), '<f4')

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-142-9410d0967245> in <module>
----> 1 show_images(trainloader)

<ipython-input-117-a32f5bd33032> in show_images(trainloader)
      1 def show_images(trainloader):
----> 2     for idx,(img,target) in enumerate(trainloader):
      3         img = img.squeeze()
      4         #pritn(img)
      5         img = torch.tensor(img)

~/venv/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    344     def __next__(self):
    345         index = self._next_index()  # may raise StopIteration
--> 346         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    347         if self._pin_memory:
    348             data = _utils.pin_memory.pin_memory(data)

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

~/venv/lib/python3.7/site-packages/torchvision/datasets/cifar.py in __getitem__(self, index)
    120         # doing this so that it is consistent with all other datasets
    121         # to return a PIL Image
--> 122         img = Image.fromarray(img)
    123 
    124         if self.transform is not None:

~/venv/lib/python3.7/site-packages/PIL/Image.py in fromarray(obj, mode)
   2645             mode, rawmode = _fromarray_typemap[typekey]
   2646         except KeyError:
-> 2647             raise TypeError("Cannot handle this data type")
   2648     else:
   2649         rawmode = mode

TypeError: Cannot handle this data type


For reference, inspecting the dataset gives:

    trainloader.dataset.__getitem__

<bound method CIFAR10.__getitem__ of Dataset CIFAR10
    Number of datapoints: 50000
    Root location: /mnt/3CE35B99003D727B/input/pytorch/data
    Split: Train
    StandardTransform Transform: Compose(
               Resize(size=32, interpolation=PIL.Image.BILINEAR)
               ToTensor()
           )
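A minimal reproduction of the same failure outside torchvision, assuming only NumPy and Pillow are installed. The float32 array stands in for one modified CIFAR-10 image; it is made-up data, not the asker's saliency output. It computes the same lookup key that appears in the KeyError above and shows that a uint8 copy of the array converts fine:

```python
import numpy as np
from PIL import Image

# Hypothetical stand-in for one image after multiplication by a float mask:
# shape (32, 32, 3), dtype float32
img_float = (np.random.rand(32, 32, 3) * 255).astype(np.float32)

# PIL builds this key from the array interface and looks it up in its mode table.
# On a little-endian machine it is ((1, 1, 3), '<f4'), matching the KeyError above.
typekey = (1, 1) + img_float.shape[2:], img_float.__array_interface__["typestr"]
print("lookup key:", typekey)

try:
    Image.fromarray(img_float)  # no PIL mode exists for 3-channel float32
    float_failed = False
except TypeError as exc:
    float_failed = True
    print("float32 image:", exc)

# The same pixels cast to uint8 map to the ordinary "RGB" mode
ok = Image.fromarray(img_float.astype(np.uint8))
print("uint8 image mode:", ok.mode)
```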

Answers (1)

Shai

Your sal_maps_hf is not of dtype np.uint8.

Based on the partial information in the question and in the comments, I guess that your mask has a floating-point dtype (np.float32 or similar). Multiplying data * sal_maps_hf then promotes the result to that float dtype instead of np.uint8. The KeyError confirms this: '<f4' is float32, and PIL's Image.fromarray has no mode for a 3-channel float32 array, so it raises the exception.
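To see the promotion in isolation, here is a small sketch with made-up arrays (data and mask are placeholders for CIFAR10.data and sal_maps_hf). NumPy's result dtype follows the "wider" operand, so a uint8 image times a float mask is no longer uint8:

```python
import numpy as np

# Placeholder for CIFAR10.data: uint8 pixels in (N, H, W, C) layout
data = np.full((2, 4, 4, 3), 200, dtype=np.uint8)

results = {}
for mask_dtype in (np.uint8, np.float32, np.float64):
    mask = np.ones(data.shape, dtype=mask_dtype)  # placeholder for sal_maps_hf
    product = data * mask
    results[np.dtype(mask_dtype).name] = product.dtype
    print(f"uint8 * {np.dtype(mask_dtype).name} -> {product.dtype}")
```

Only the uint8 mask keeps the result in uint8, which is the only one of these dtypes that CIFAR10.__getitem__ can pass to Image.fromarray as an RGB image.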

Try:

trainloader.dataset.data = (data * sal_maps_hf).reshape(data.shape).astype(np.uint8)
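If sal_maps_hf is not guaranteed to lie in [0, 1], the plain astype(np.uint8) cast can produce wrapped or otherwise meaningless pixels, because NumPy does not saturate out-of-range values when casting float to uint8. A slightly more defensive sketch rounds and clips first. apply_saliency is a hypothetical helper name, not part of NumPy or torchvision, and it assumes sal_maps broadcasts against data:

```python
import numpy as np

def apply_saliency(data: np.ndarray, sal_maps: np.ndarray) -> np.ndarray:
    """Hypothetical helper: weight uint8 images by saliency maps, return uint8.

    Assumes data is uint8 with shape (N, H, W, C) and sal_maps broadcasts to it.
    """
    # Multiply in float32 so a uint8 mask cannot overflow during the product
    weighted = data.astype(np.float32) * sal_maps
    # Round, then clip to the valid pixel range before the final cast
    return np.clip(np.rint(weighted), 0, 255).astype(np.uint8)

# Tiny example: one 1x1 RGB "image"; the 3.0 weight pushes a channel past 255
data = np.array([[[[200, 100, 50]]]], dtype=np.uint8)
sal = np.array([[[[0.5, 3.0, 1.0]]]], dtype=np.float32)
out = apply_saliency(data, sal)
print(out.dtype, out.ravel())
```

With the real arrays this would be used as trainloader.dataset.data = apply_saliency(data, sal_maps_hf); the reshape is unnecessary since both arrays already share the shape (50000, 32, 32, 3).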
