edoost

Reputation: 101

How can I save PyTorch's DataLoader instance?

I want to save PyTorch's torch.utils.data.dataloader.DataLoader instance, so that I can continue training where I left off (keeping shuffle seed, states and everything).

Upvotes: 4

Views: 11223

Answers (3)

b-fg

Reputation: 4137

Native PyTorch support for this is still not available, but it is being considered as a future improvement. In the meantime, see the other answers for custom implementations.

Upvotes: 1

edoost

Reputation: 101

It's quite simple: design your own Sampler that takes the starting batch index and shuffles the indices itself (shuffle the indices, not the dataset, and seed the RNG so the same permutation is reproduced when you resume):

import random
from torch.utils.data import Sampler


random.seed(224)  # fix the seed so the shuffle order is identical across runs


class MySampler(Sampler):
    def __init__(self, data, batch_size, i=0):
        self.seq = list(range(len(data)))
        random.shuffle(self.seq)              # shuffle indices, not the data itself
        self.seq = self.seq[i * batch_size:]  # skip the batches already consumed

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)

Now save the index i of the last completed batch somewhere, and the next time instantiate the DataLoader with it:

train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, batch_size, last_i)
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               sampler=train_sampler,
                               shuffle=False)  # a custom sampler requires shuffle=False

It's quite useful when training on Colab.
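To see why the fixed seed makes resuming deterministic, here is a minimal, framework-free sketch of the same idea (the function name `epoch_indices` and the concrete numbers are illustrative, not from the answer above): reseeding before the shuffle reproduces the same permutation, so slicing off the first i * batch_size entries yields exactly the indices that were not yet consumed.

```python
import random

def epoch_indices(n, seed=224, start_batch=0, batch_size=4):
    """Reproduce the epoch's shuffled index order, skipping consumed batches."""
    rng = random.Random(seed)              # fixed seed -> identical shuffle every run
    seq = list(range(n))
    rng.shuffle(seq)
    return seq[start_batch * batch_size:]  # resume mid-epoch

full = epoch_indices(12)                    # the original run's order
resumed = epoch_indices(12, start_batch=2)  # restart after 2 batches of 4
assert resumed == full[8:]                  # resumed order is exactly the unseen tail
```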

Upvotes: 3

usamec

Reputation: 2394

You need a custom implementation of the sampler. A hassle-free one is available at: https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5

You can save and resume like:

sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

for x in loader:
    print(x)
    break

sampler2 = ResumableRandomSampler(dataset)
torch.save(sampler.get_state(), "test_samp.pth")
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)
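The core of such a sampler can be sketched without PyTorch at all (this is an illustrative sketch of the get_state/set_state pattern, not the gist's actual code): keep the shuffled permutation plus a position counter as the state, and restoring that state lets a fresh sampler continue yielding exactly the indices the interrupted one had not produced yet.

```python
import random

class ResumableSampler:
    """Sketch of a resumable random sampler (illustrative, not the gist's code)."""
    def __init__(self, n, seed=0):
        rng = random.Random(seed)
        self.perm = list(range(n))
        rng.shuffle(self.perm)   # fixed permutation for this epoch
        self.pos = 0             # how many indices have been yielded so far

    def __iter__(self):
        while self.pos < len(self.perm):
            self.pos += 1
            yield self.perm[self.pos - 1]

    def get_state(self):
        return {"perm": list(self.perm), "pos": self.pos}

    def set_state(self, state):
        self.perm = list(state["perm"])
        self.pos = state["pos"]

s1 = ResumableSampler(6, seed=1)
it = iter(s1)
first = [next(it), next(it)]      # consume two indices, then "interrupt" training
s2 = ResumableSampler(6, seed=99) # fresh sampler, even with a different seed
s2.set_state(s1.get_state())      # restore the saved state
rest = list(iter(s2))
assert first + rest == s1.perm    # resumed run completes the original permutation
```

In the real version, the state dict is what gets passed through torch.save/torch.load between sessions.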

Upvotes: 5
