a_jelly_fish

How to create a PyTorch Dataset from .pt files?

I have transformed MNIST images and saved them as .pt files in a folder in Google Drive. I'm writing my PyTorch code in Colab.

I would like to use these files to create a Dataset that returns these images as tensors. How can I do this?

Transforming the images during training took too long, so I transformed them once in advance and saved them all as .pt files. Now I just want to load them back as a dataset and use them in my model.
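For reference, the transform-once-and-save step described above could look like the following minimal sketch. The helper name `cache_transformed`, the one-file-per-sample layout, and the `{"image": ..., "label": ...}` dictionary format are assumptions made for illustration, not necessarily how the files in this question were written.

```python
import os

import torch


def cache_transformed(samples, transform, out_dir):
    """Apply `transform` once to every (image, label) pair and save each
    result to its own .pt file, so training can skip the transform step.

    `samples` is any iterable of (image, label) pairs (hypothetical helper).
    """
    os.makedirs(out_dir, exist_ok=True)
    count = 0
    for i, (image, label) in enumerate(samples):
        tensor = transform(image)
        # Zero-padded names make a lexicographic sort match dataset order.
        path = os.path.join(out_dir, f"{i:06d}.pt")
        # Assumed on-disk format: one dict per sample.
        torch.save({"image": tensor, "label": int(label)}, path)
        count += 1
    return count
```

With torchvision, `samples` could be `torchvision.datasets.MNIST(root, train=True, download=True)` created without a transform, and `out_dir` a folder on the mounted Google Drive.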

Answers (1)

Wasi Ahmad

Saving the pre-transformed images is indeed a good idea. In that case, you can simply write your own Dataset class to load them.

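import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.sampler import RandomSampler

class ReaderDataset(Dataset):
    def __init__(self, filename):
        # load the images from file; this assumes the file holds a single
        # tensor of shape [N, C, H, W] saved with torch.save
        self.images = torch.load(filename)

    def __len__(self):
        # return total dataset size
        return len(self.images)

    def __getitem__(self, index):
        # return one dataset element; the DataLoader groups elements into batches
        return self.images[index]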
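The question describes a folder of many .pt files rather than a single file. A minimal sketch of that variant keeps only the sorted file paths in `__init__` and loads one sample per call in `__getitem__`. It assumes each file holds one sample saved as `{"image": tensor, "label": int}`; that format and the class name `FolderPtDataset` are hypothetical.

```python
import glob
import os

import torch
from torch.utils.data import Dataset


class FolderPtDataset(Dataset):
    """Dataset over a folder with one saved sample per .pt file.

    Assumes (hypothetically) that each file was written with
    torch.save({"image": tensor, "label": int}, path).
    """

    def __init__(self, folder):
        # Only collect the file paths here; tensors are loaded lazily.
        self.paths = sorted(glob.glob(os.path.join(folder, "*.pt")))
        if not self.paths:
            raise FileNotFoundError(f"no .pt files found in {folder!r}")

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        # Load a single sample from disk and return it as (image, label).
        sample = torch.load(self.paths[index])
        return sample["image"], sample["label"]
```

Loading lazily keeps memory use low. For a dataset the size of MNIST, loading every tensor once in `__init__` (about 188 MB for 60,000 float32 images of shape 1×28×28) is also an option and avoids a disk read per item.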

Then you can create a DataLoader as follows:

train_dataset = ReaderDataset(filepath)
train_sampler = RandomSampler(train_dataset)
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    sampler=train_sampler,
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel
)
# args is an argparse.Namespace (or similar object) holding the parameters,
# filepath is the path to the saved .pt file, and
# batchify is a custom collate function that prepares each mini-batch
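To check that the pieces fit together, here is a self-contained sketch of the same DataLoader call. `TensorPairDataset` is a hypothetical stand-in for `ReaderDataset` filled with random data, `batchify` is one possible collate function, and the `args` values are made up. `SimpleNamespace` is used because the snippet reads parameters through attribute access, as an `argparse.Namespace` would provide.

```python
from types import SimpleNamespace

import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler


class TensorPairDataset(Dataset):
    """Hypothetical stand-in for ReaderDataset: in-memory images and labels."""

    def __init__(self, images, labels):
        self.images, self.labels = images, labels

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        return self.images[index], self.labels[index]


def batchify(batch):
    # batch is a list of (image, label) pairs, one per sampled index.
    images = torch.stack([image for image, _ in batch])
    labels = torch.tensor([label for _, label in batch])
    return images, labels


# Hypothetical parameter values, accessed as attributes like in the answer.
args = SimpleNamespace(batch_size=4, data_workers=0,
                       cuda=torch.cuda.is_available(), parallel=False)

dataset = TensorPairDataset(torch.rand(10, 1, 28, 28), list(range(10)))
loader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    sampler=RandomSampler(dataset),
    num_workers=args.data_workers,
    collate_fn=batchify,
    pin_memory=args.cuda,
    drop_last=args.parallel,
)
```

For `(image, label)` pairs like these, PyTorch's default collate function already stacks the images and turns the labels into a tensor, so a custom `collate_fn` is only needed when batches require extra processing, such as padding.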
