hanugm

How can I access the next step data using DataLoader in PyTorch?

I am using code that trains a neural network. It uses PyTorch's DataLoader to load the data at every iteration. The code looks as follows:

for step, data in enumerate(dataloader, 0):
    # ...
    output = neuralnetwork_model(data)
    # ...

Here step is an integer taking the values 0, 1, 2, 3, ..., and data is a batch of samples at each step. At each step the code passes the corresponding batch to the neural network.

At step n, I just need to access the data of step n+1. I need something like this:

for step, data in enumerate(dataloader, 0):
    # ...
    output = neuralnetwork_model(data)
    access = data_of_next_step
    # ...

How can I achieve this?

Answers (1)

Ivan

It is handier to do this at the iteration level than to change the data loader's implementation. Following the answers to Iterate over n successive elements with overlap, you can achieve this with itertools.tee:

from itertools import tee

def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)  # advance the second iterator by one element
    return zip(a, b)
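As a quick sanity check, the sketch below repeats the recipe and runs it on a plain list of strings, a hypothetical stand-in for the batches a data loader would produce. Since a DataLoader is an ordinary Python iterable, tee consumes it the same way. On Python 3.10 and later, itertools.pairwise provides this behavior built in.

```python
from itertools import tee


def pairwise(iterable):
    "s -> (s0,s1), (s1,s2), (s2, s3), ..."
    a, b = tee(iterable)
    next(b, None)  # advance the second iterator by one element
    return zip(a, b)


# Hypothetical stand-in for a DataLoader: four "batches".
batches = ["b0", "b1", "b2", "b3"]
pairs = list(pairwise(batches))
print(pairs)  # [('b0', 'b1'), ('b1', 'b2'), ('b2', 'b3')]
```

Note that n batches give only n - 1 pairs: the last batch never appears as the first element of a pair, so a loop over these pairs never treats it as the current batch.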

Then you simply iterate over the data loader wrapped in pairwise:

for batch1, batch2 in pairwise(dataloader):
    # batch1 is the current batch
    # batch2 is the batch of the following step
    output = neuralnetwork_model(batch1)
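If the final batch must also be processed, one alternative (a sketch, not part of the original recipe) is a small lookahead generator, given the hypothetical name with_next here, that pairs each batch with the following one and yields None once the loader is exhausted. Wrapping it in enumerate keeps the step counter from the question's loop. A plain list again stands in for the DataLoader.

```python
def with_next(iterable):
    """Yield (current, following) pairs; following is None for the last item.

    Hypothetical helper for illustration, not a PyTorch API.
    """
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: nothing to yield
    for following in it:
        yield current, following
        current = following
    yield current, None  # the final item has no successor


# Hypothetical stand-in for a DataLoader: three "batches".
batches = ["b0", "b1", "b2"]
for step, (data, next_data) in enumerate(with_next(batches)):
    print(step, data, next_data)
# 0 b0 b1
# 1 b1 b2
# 2 b2 None
```

This assumes no batch is itself None, which is the usual case for what a DataLoader yields. Like the tee approach, it keeps one extra batch in memory at a time.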
