njho

Reputation: 2158

How does `images, labels = dataiter.next()` work in the PyTorch tutorial?

In the cifar10_tutorial, how are images and labels assigned?

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=2)

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))


# get some random training images
dataiter = iter(trainloader)

images, labels = dataiter.next()

How does the last line, images, labels = dataiter.next(), know how to assign images and labels automatically?

I checked the DataLoader class and the DataLoaderIter class, but I think I need a bit more knowledge of iterators in general.

Upvotes: 6

Views: 12081

Answers (1)

cvanelteren

Reputation: 1703

I think it is crucial to understand the difference between an iterable and an iterator. An iterable is an object that you can iterate over, such as a list or a tuple. An iterator is the object that does the iterating: you get one by calling iter() on an iterable, and each call to its __next__ method (usually made through the built-in next()) returns the next item.

A simple example is the following. Take an iterable (here a tuple), obtain an iterator from it with iter(), and call next() on it to fetch the items one at a time. The loop prints each item; once the tuple is exhausted, next() raises a StopIteration exception.

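test = (1, 2, 3)       # an iterable
tester = iter(test)    # an iterator over that iterable

while True:
    # next() calls tester.__next__(); it raises StopIteration after the last item
    nextItem = next(tester)
    print(nextItem)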
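A for loop performs the same steps implicitly: it calls iter() once, calls next() repeatedly, and stops quietly when StopIteration is raised. Here is a minimal sketch of that equivalence; the helper name manual_for is made up for illustration and is not part of Python or PyTorch.

```python
def manual_for(iterable):
    """Collect items the way a for loop would, using iter() and next() explicitly."""
    collected = []
    iterator = iter(iterable)      # ask the iterable for a fresh iterator
    while True:
        try:
            item = next(iterator)  # calls iterator.__next__()
        except StopIteration:      # raised once the iterator is exhausted
            break
        collected.append(item)
    return collected


test = (1, 2, 3)
# Both expressions visit the same items in the same order.
print(manual_for(test) == [item for item in test])
```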

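The class you refer to above probably has an implementation similar to this, except that each item it returns is a two-element sequence holding a batch of images and the matching batch of labels. The assignment images, labels = ... is then ordinary Python sequence unpacking: the first element is bound to images and the second to labels.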
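To make the unpacking concrete, here is a small sketch under stated assumptions: ToyLoader is a hypothetical stand-in written for this answer, not PyTorch's actual DataLoader, and the strings stand in for image tensors.

```python
class ToyLoader:
    """Hypothetical stand-in for a data loader; not PyTorch's implementation."""

    def __init__(self, samples, batch_size):
        self.samples = samples          # list of (image, label) pairs
        self.batch_size = batch_size

    def __iter__(self):
        # A generator function returns an iterator, so iter(loader) works.
        for start in range(0, len(self.samples), self.batch_size):
            batch = self.samples[start:start + self.batch_size]
            images = [image for image, _ in batch]
            labels = [label for _, label in batch]
            yield images, labels        # one two-element tuple per batch


samples = [("img0", 3), ("img1", 7), ("img2", 1), ("img3", 4)]
dataiter = iter(ToyLoader(samples, batch_size=2))

images, labels = next(dataiter)         # unpack the two-element batch
print(images, labels)
```

The built-in next(dataiter) works with any iterator, whereas dataiter.next() only works if the iterator class happens to define a method with that name.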

Upvotes: 8
