Reputation: 101
I want to save PyTorch's torch.utils.data.dataloader.DataLoader instance, so that I can continue training where I left off (keeping the shuffle seed, state, and everything).
Upvotes: 4
Views: 11223
Reputation: 4137
Native PyTorch support for this is still not available, but it is being considered as a future improvement. In the meantime, see the other answers for custom implementations.
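As a partial workaround until then, you can at least save and restore the global RNG state that the default shuffling draws from, so the shuffle order is reproducible across restarts (this does not remember how far into the epoch the loader was). A minimal sketch:

import torch

# at checkpoint time: save the global RNG state that seeds the shuffling sampler
torch.save(torch.get_rng_state(), "rng_state.pth")

# in the new run, restore it before re-creating the DataLoader
torch.set_rng_state(torch.load("rng_state.pth"))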
Upvotes: 1
Reputation: 101
It's quite simple. You can design your own Sampler that takes the starting index and does the shuffling itself:
import random

from torch.utils.data import Sampler

random.seed(224)  # use a fixed number so the permutation is the same on every run

class MySampler(Sampler):
    def __init__(self, data, batch_size, i=0):
        # shuffle a list of indices (rather than the dataset itself) and
        # drop the indices belonging to the first i already-seen batches
        self.seq = list(range(len(data)))
        random.shuffle(self.seq)
        self.seq = self.seq[i * batch_size:]

    def __iter__(self):
        return iter(self.seq)

    def __len__(self):
        return len(self.seq)
Now save the last batch index i somewhere, and the next time instantiate the DataLoader using it:
train_dataset = MyDataset(train_data)
train_sampler = MySampler(train_dataset, batch_size, last_i)
train_data_loader = DataLoader(dataset=train_dataset,
                               batch_size=batch_size,
                               sampler=train_sampler,
                               shuffle=False)  # keep DataLoader's shuffle set to False when a custom sampler is used
This is quite useful when training on Colab, where sessions can be interrupted.
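One way to obtain last_i is to persist the running batch index inside the training loop. A minimal sketch, assuming a hypothetical train_step function standing in for your own loop body:

import json

for i, batch in enumerate(train_data_loader, start=last_i):
    train_step(batch)  # hypothetical: your forward/backward/optimizer step
    # persist the index of the next batch so a restart can resume from it
    with open("last_i.json", "w") as f:
        json.dump({"last_i": i + 1}, f)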
Upvotes: 3
Reputation: 2394
You need a custom implementation of the sampler. A hassle-free one is available at: https://gist.github.com/usamec/1b3b4dcbafad2d58faa71a9633eea6a5
You can save and resume like this:
sampler = ResumableRandomSampler(dataset)
loader = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler, pin_memory=True)

# consume one batch, then stop as if training were interrupted
for x in loader:
    print(x)
    break

# save the sampler's position and RNG state
torch.save(sampler.get_state(), "test_samp.pth")

# a fresh sampler in a new run picks up where the old one stopped
sampler2 = ResumableRandomSampler(dataset)
sampler2.set_state(torch.load("test_samp.pth"))
loader2 = torch.utils.data.DataLoader(dataset, batch_size=2, sampler=sampler2, pin_memory=True)

for x in loader2:
    print(x)
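To resume actual training rather than just replaying a batch, the sampler state can be bundled into the same checkpoint as the model and optimizer. A minimal sketch, assuming model, optimizer, and epoch come from your own training loop (the checkpoint keys are illustrative):

# save: store the sampler state alongside the usual model/optimizer state
torch.save({
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "sampler": sampler.get_state(),
    "epoch": epoch,
}, "checkpoint.pth")

# resume: restore everything, including where the sampler left off
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
sampler.set_state(checkpoint["sampler"])
epoch = checkpoint["epoch"]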
Upvotes: 5