Dr. John

PyTorch - Subclasses of torchvision.datasets.ImageFolder - Import Error

Following my last post, I am now trying to implement a subclass of the torchvision.datasets.ImageFolder class. The following code raises an error ("name 'default_loader' is not defined"), and I can't figure out why. Could you please help me?

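import torchvision

class ExtendingImageFolder(torchvision.datasets.ImageFolder):
    # The NameError is raised here: default values are evaluated when the class body runs,
    # and default_loader has not been imported into this module.
    def __init__(self, root, transform=None, target_transform=None, loader=default_loader):
        super().__init__(root, transform, target_transform, loader)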

When I delete the "None" and "default_loader" defaults and write it like this:

class ExtendingImageFolder(torchvision.datasets.ImageFolder):
    # No defaults: every argument is now required when creating an instance.
    def __init__(self, root, transform, target_transform, loader):
        super().__init__(root, transform, target_transform, loader)

I get a TypeError about missing input arguments when I try to create an instance of this class, for example:

JJ = ExtendingImageFolder(root='C:/', transform=transform)

What am I doing wrong here?

Thanks in advance!

Answers (1)

benjaminplanche

default_loader() is a function defined in torchvision/datasets/folder.py, alongside ImageFolder and the other folder-based dataset helpers.

However, it is not exported in torchvision/datasets/__init__.py (unlike ImageFolder). You can still import it directly with "from torchvision.datasets.folder import default_loader", which should solve your error.
