Trong Van

Reputation: 382

Fixing error: img should be PIL Image. Got <class 'torch.Tensor'>

I tried to create a custom dataset, but when I try to show some of the images, I get this error. Here are my Dataset class and transforms:

transforms = transforms.Compose([transforms.Resize(224,224)])

class MyDataset(Dataset):
    def __init__(self, path, label,  transform=None):
        self.path = glob.glob(os.path.join(path, '*.jpg'))
        self.transform = transform
        self.label = label

    def __getitem__(self, index):
        img = io.imread(self.path[index])
        img = torch.tensor(img)
        labels = torch.tensor(int(self.label))
        if self.transform:
            img = self.transform(img)
        return (img,labels)
    
    def __len__(self):
        return len(self.path)

And here is the line that raises the error:

images, labels = next(iter(train_loader))

Answers (1)

Szymon Maszke

Reputation: 24701

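transforms.Resize requires a PIL.Image instance as input, while your img is a torch.Tensor. (This is the behavior of torchvision versions before 0.8; newer versions also accept tensors, but expect them in channels-first [..., H, W] layout, whereas io.imread returns an H x W x C array.)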

This should solve your issue (see the comments in the code):

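import glob
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

# Cast the PIL Image to a tensor (ToTensor) only after all
# PIL-based transforms have run.
# Resize takes the size as one (h, w) tuple; in Resize(224, 224) the second
# 224 would be passed as the interpolation argument.
# Naming the result `transform` avoids shadowing the `transforms` module.
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])


class MyDataset(Dataset):
    def __init__(self, path, label, transform=None):
        self.path = glob.glob(os.path.join(path, "*.jpg"))
        self.transform = transform
        self.label = label

    def __getitem__(self, index):
        # Open as a PIL Image; convert("RGB") keeps grayscale/RGBA files at 3 channels
        img = Image.open(self.path[index]).convert("RGB")
        labels = torch.tensor(int(self.label))
        if self.transform is not None:
            # Now you have a PIL.Image instance for the transforms
            img = self.transform(img)
        return (img, labels)

    def __len__(self):
        return len(self.path)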
