Reputation: 382
I tried to create a custom dataset, but when I try to show some of the images I get an error. Here are my Dataset class and transforms:
# Preprocessing pipeline from the question.
# NOTE(review): Resize(224,224) passes the second 224 positionally as Resize's
# `interpolation` argument; a 224x224 output is written Resize((224, 224)).
# NOTE(review): this assignment rebinds the name `transforms`, so the
# torchvision `transforms` module is no longer reachable under that name after
# this line runs.
transforms = transforms.Compose([transforms.Resize(224,224)])
class MyDataset(Dataset):
    """Dataset from the question: every ``*.jpg`` file in one folder, all
    sharing a single label.

    This is the failing version: each image is turned into a ``torch.Tensor``
    before ``transform`` is applied to it.
    """

    def __init__(self, path, label, transform=None):
        # Files matching *.jpg directly inside `path` (the pattern is not recursive).
        # NOTE(review): glob's ordering is filesystem-dependent; sort it if a
        # stable index -> file mapping matters.
        self.path = glob.glob(os.path.join(path, '*.jpg'))
        # Optional callable applied to each image in __getitem__.
        self.transform = transform
        # One label shared by every image in this dataset.
        self.label = label

    def __getitem__(self, index):
        # presumably skimage.io.imread (the import is not shown); expected to
        # return a NumPy array laid out H x W x C for colour JPEGs — TODO confirm.
        img = io.imread(self.path[index])
        # Copies the array into a torch.Tensor with the same shape/layout.
        img = torch.tensor(img)
        labels = torch.tensor(int(self.label))
        if self.transform:
            # NOTE(review): the transform receives a torch.Tensor, not a
            # PIL.Image. Older torchvision Resize only accepts PIL images, and
            # tensor input is expected as [..., H, W] rather than H x W x C.
            img = self.transform(img)
        return (img,labels)

    def __len__(self):
        # Number of .jpg files found when the dataset was constructed.
        return len(self.path)
And here is the line that raises the error:
# Fetch a single batch to inspect it; this is the line that raises.
# presumably train_loader is a DataLoader over MyDataset (not shown), so the
# error surfaces here because iterating it calls MyDataset.__getitem__.
images, labels = next(iter(train_loader))
Upvotes: 2
Views: 2611
Reputation: 24701
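`transforms.Resize` requires a `PIL.Image` instance as input, while your `img` is a `torch.Tensor`. (torchvision 0.8 and later also accept tensors, but expect them in `[..., H, W]` layout, whereas `io.imread` returns colour images as height × width × channels arrays.)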
This will solve your issue (see comments in source code):
import torchvision
from PIL import Image
from torchvision import transforms as T

# Resize operates on the PIL image that MyDataset.__getitem__ returns.
# ToTensor comes last, so the image becomes a torch.Tensor only once no
# further PIL-only transforms are needed.
#
# Resize takes the output size as ONE argument, so a square output needs the
# tuple (224, 224). Writing Resize(224, 224) would pass the second 224 as
# Resize's `interpolation` parameter.
#
# The torchvision module is imported as `T` so that binding the composed
# pipeline to the name `transforms` (kept for callers that pass
# `transform=transforms`) does not break the module lookups on this line.
transforms = T.Compose([T.Resize((224, 224)), T.ToTensor()])
class MyDataset(Dataset):
    """Image dataset that gives one fixed label to every matching file in a folder.

    Args:
        path: Directory searched (non-recursively) for image files.
        label: Label assigned to every image; converted with ``int()`` each
            time an item is fetched.
        transform: Optional callable applied to the ``PIL.Image`` of each
            item, e.g. a torchvision ``Compose`` ending in ``ToTensor``.
        pattern: Glob pattern, relative to ``path``, selecting the files.
            Defaults to ``"*.jpg"``, the previously hard-coded pattern.
        mode: PIL mode every image is converted to before ``transform`` runs.
            The default ``"RGB"`` gives grayscale / CMYK / RGBA files the same
            three channels as ordinary JPEGs, so a batch can be stacked.
            Pass ``None`` to keep each file's own mode.
    """

    def __init__(self, path, label, transform=None, *, pattern="*.jpg", mode="RGB"):
        # Sorted so that index -> file is reproducible; glob's own ordering
        # depends on the filesystem.
        self.path = sorted(glob.glob(os.path.join(path, pattern)))
        self.transform = transform
        self.label = label
        self.mode = mode

    def __getitem__(self, index):
        """Return ``(image, label_tensor)`` for the file at ``index``."""
        # Image.open reads lazily and keeps the file open. convert() and
        # copy() both load the pixel data, so the handle can be closed when
        # the `with` block exits rather than whenever the object is collected.
        with Image.open(self.path[index]) as src:
            img = src.convert(self.mode) if self.mode is not None else src.copy()
        labels = torch.tensor(int(self.label))
        if self.transform is not None:
            # The transform receives a PIL.Image, which PIL-based torchvision
            # transforms such as Resize expect.
            img = self.transform(img)
        return (img, labels)

    def __len__(self):
        """Return the number of files matched at construction time."""
        return len(self.path)
Upvotes: 3