blue-sky

Reputation: 53916

Understanding PyTorch training batches

After reading https://stanford.edu/~shervine/blog/pytorch-how-to-generate-data-parallel and https://discuss.pytorch.org/t/how-does-enumerate-trainloader-0-work/14410, I'm trying to understand how training epochs behave in PyTorch.

Take this outer and inner loop:

for epoch in range(num_epochs):
    for i1, i2 in enumerate(training_loader):
        ...  # per-batch work goes here

Is this a correct interpretation?

On each pass of the outer loop (each epoch), the entire training set, training_loader in the example above, is iterated batch by batch. This means the model does not process a single instance per training cycle. Instead, per training cycle ( for epoch in range(num_epochs): ), the whole training set is processed in chunks/batches, where the batch size is set when training_loader is created.

Upvotes: 0

Views: 1556

Answers (1)

Haran Rajkumar

Reputation: 2395

torch.utils.data.DataLoader returns an iterable that iterates over the dataset in batches.

Therefore, the following -

import torch

training_loader = torch.utils.data.DataLoader(*args)
for i1, i2 in enumerate(training_loader):
    pass  # process batch i2 here; i1 is the batch index

runs once over the complete dataset, in batches.
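To make that concrete, here is a minimal sketch that counts batches and samples per epoch. The toy dataset (10 samples, 3 features), the batch size of 4, and the names features, labels, batches_per_epoch and samples_per_epoch are assumptions made up for illustration; they are not from the question.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy data: 10 samples with 3 features each, plus one label per sample.
features = torch.arange(30, dtype=torch.float32).reshape(10, 3)
labels = torch.arange(10)
dataset = TensorDataset(features, labels)

# The batch size is fixed here, when the loader is created.
training_loader = DataLoader(dataset, batch_size=4, shuffle=True)

num_epochs = 2
batches_per_epoch = []
samples_per_epoch = []

for epoch in range(num_epochs):
    num_batches = 0
    num_samples = 0
    # enumerate yields (batch index, batch); each batch is (inputs, targets).
    for batch_idx, (inputs, targets) in enumerate(training_loader):
        num_batches += 1
        num_samples += inputs.shape[0]
        # A real training loop would do the forward pass, loss,
        # backward pass and optimizer step for this batch here.
    batches_per_epoch.append(num_batches)
    samples_per_epoch.append(num_samples)
    print(f"epoch {epoch}: {num_batches} batches, {num_samples} samples")
```

With these numbers, every epoch sees all 10 samples, split into 3 batches of sizes 4, 4 and 2. The last batch is smaller because drop_last defaults to False. So the interpretation in the question is right: one epoch is one full pass over the training set, taken one batch at a time.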

Upvotes: 0
