Thoth

Reputation: 1041

Efficiently sample batches from only one class at each iteration with PyTorch

I want to train a classifier on the ImageNet dataset (1000 classes), and I need each batch to contain 64 images from the same class, with consecutive batches drawn from different classes. So far, based on @Shai's suggestion and this post, I have:

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os


class DS(Dataset):
    """Wrap a map-style dataset and group its sample indices by class label.

    Args:
        data: the dataset to wrap. ``data[i]`` is expected to return a
            ``(sample, label)`` pair whose label is an integer in
            ``[0, num_classes)``.
        num_classes: number of distinct labels; sizes the per-class lists.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data

        # A list of lists: self.indices[c] holds, in ascending order, the
        # indices of the samples whose label is c.
        self.indices = [[] for _ in range(num_classes)]
        for i, class_label in enumerate(self._iter_labels(data)):
            # int() also accepts 0-d tensors / numpy integers as labels.
            self.indices[int(class_label)].append(i)

    @staticmethod
    def _iter_labels(data):
        """Return an iterator over the label of every sample, in index order.

        If ``data`` exposes a ``targets`` sequence (torchvision's
        ``ImageFolder`` does) and has no ``target_transform``, labels are read
        from it directly, so no image is loaded or transformed. Otherwise
        every sample is fetched and its second element is used as the label.
        """
        targets = getattr(data, 'targets', None)
        if targets is not None and getattr(data, 'target_transform', None) is None:
            # ``targets`` holds raw labels; they equal data[i][1] only when
            # __getitem__ applies no target_transform.
            return iter(targets)
        return (class_label for _, class_label in data)

    def classes(self):
        """Return the per-class index lists built in ``__init__``."""
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)


class BatchSampler:
    """Batch sampler whose batches each contain samples from a single class.

    Every batch holds ``batch_size`` distinct indices drawn from one class.
    Classes are visited in successive random permutations (reshuffled on every
    ``__iter__`` call, i.e. every epoch). When there is more than one class,
    no two consecutive batches share a class.

    Args:
        classes: list of lists; ``classes[c]`` holds the dataset indices of
            the samples that belong to class ``c``.
        batch_size: number of indices per batch; must not exceed the size of
            the smallest class.

    Raises:
        ValueError: if ``batch_size`` is larger than the smallest class.
    """

    def __init__(self, classes, batch_size):
        self.classes = classes
        self.n_batches = sum(len(x) for x in classes) // batch_size
        self.min_class_size = min(len(x) for x in classes)
        self.batch_size = batch_size
        self.class_range = list(range(len(self.classes)))
        random.shuffle(self.class_range)

        # Sampling without replacement needs at least batch_size samples in
        # every class.
        if batch_size > self.min_class_size:
            raise ValueError('batch_size should be at most {}'.format(self.min_class_size))

    def _class_order(self):
        """Return ``n_batches`` class ids built from concatenated fresh
        shuffles of ``class_range``, with no id repeated back to back when
        there is more than one class."""
        order = []
        while len(order) < self.n_batches:
            random.shuffle(self.class_range)
            perm = list(self.class_range)
            if order and len(perm) > 1 and perm[0] == order[-1]:
                # Don't start the new permutation with the class the previous
                # one ended on.
                perm[0], perm[-1] = perm[-1], perm[0]
            order.extend(perm)
        return order[:self.n_batches]

    def __iter__(self):
        for batch_class in self._class_order():
            # replace=False: a batch never contains the same sample twice.
            yield np.random.choice(self.classes[batch_class], self.batch_size,
                                   replace=False).tolist()

    def __len__(self):
        return self.n_batches


def main():
    """Build a one-class-per-batch loader over the global ``train_dataset``
    and report how many distinct classes one pass over it visits.

    Relies on the module-level globals ``train_dataset`` and ``args`` set in
    the ``__main__`` block.
    """
    # FakeData exposes ``num_classes``; ImageFolder does not, but lists its
    # class names in ``classes``.
    num_classes = getattr(train_dataset, 'num_classes', None)
    if num_classes is None:
        num_classes = len(train_dataset.classes)

    _train_dataset = DS(train_dataset, num_classes)
    _batch_sampler = BatchSampler(_train_dataset.classes(), batch_size=args.batch_size)
    _train_loader = DataLoader(dataset=_train_dataset, batch_sampler=_batch_sampler,
                               num_workers=args.workers, pin_memory=True)
    labels = []
    for _inputs, _labels in _train_loader:
        # .item() raises if the batch holds more than one distinct label, so
        # this also checks that every batch is single-class.
        batch_label = torch.unique(_labels).item()
        labels.append(batch_label)
        print("Unique labels: {}".format(batch_label))

    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))


if __name__ == "__main__":
    # Parse options; main() reads the result through the global ``args``.
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    # Build the global train/val datasets that main() consumes.
    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))

    # Plain shuffled/sequential loaders. main() does not use them; it builds
    # its own single-class loader from train_dataset.
    train_sampler, val_sampler = None, None
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=(train_sampler is None),
        num_workers=args.workers, pin_memory=True, sampler=train_sampler)

    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=args.workers, pin_memory=True, sampler=val_sampler)

    main()

which prints: Length of traversed unique labels: 100.

However, building `self.indices` in the `for` loop takes a lot of time. Is there a more efficient way to construct this sampler?

EDIT: implementation using `yield`:

import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import numpy as np
import random
import argparse
import torch
import os
from tqdm import tqdm
import os.path


class DS(Dataset):
    """Wrap a map-style dataset and group its sample indices by class label.

    Args:
        data: the dataset to wrap. ``data[i]`` is expected to return a
            ``(sample, label)`` pair whose label is an integer in
            ``[0, num_classes)``.
        num_classes: number of distinct labels; sizes the per-class lists.
    """

    def __init__(self, data, num_classes):
        super(DS, self).__init__()
        self.data = data
        self.data_len = len(data)

        # indices[c] lists, in ascending order, the indices of samples whose
        # label is c.
        indices = [[] for _ in range(num_classes)]

        targets = getattr(data, 'targets', None)
        if targets is not None and getattr(data, 'target_transform', None) is None:
            # Fast path: datasets such as torchvision's ImageFolder keep the
            # raw labels in ``targets``, so no image has to be loaded. They
            # match data[i][1] only when no target_transform is applied.
            for i, class_label in enumerate(targets):
                indices[int(class_label)].append(i)
        else:
            # Slow path: load every sample just to read its label.
            for i, (_, class_label) in tqdm(enumerate(data), total=len(data), miniters=1,
                                            desc='Building class indices dataset..'):
                indices[int(class_label)].append(i)

        self.indices = indices

    def per_class_sample_indices(self):
        """Return the per-class index lists built in ``__init__``."""
        return self.indices

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.data_len


class BatchSampler:
    """Batch sampler whose batches each contain samples from a single class.

    A batch holds ``batch_size`` distinct indices drawn from one class. A
    class with fewer samples yields all of its indices as one short batch.
    Non-empty classes are visited in successive random permutations
    (reshuffled on every ``__iter__`` call, i.e. every epoch). When there is
    more than one such class, no two consecutive batches share a class.

    Args:
        per_class_sample_indices: list of lists; entry ``c`` holds the
            dataset indices of the samples that belong to class ``c``.
        batch_size: number of indices per batch.
    """

    def __init__(self, per_class_sample_indices, batch_size):
        self.per_class_sample_indices = per_class_sample_indices
        self.n_batches = sum(len(x) for x in per_class_sample_indices) // batch_size
        self.min_class_size = min(len(x) for x in per_class_sample_indices)
        self.batch_size = batch_size
        # Empty classes cannot supply a batch, so they are left out.
        self.class_range = [c for c, members in enumerate(per_class_sample_indices)
                            if len(members) > 0]
        random.shuffle(self.class_range)

    def _class_order(self):
        """Return ``n_batches`` class ids built from concatenated fresh
        shuffles of ``class_range``, with no id repeated back to back when
        there is more than one class."""
        order = []
        while self.class_range and len(order) < self.n_batches:
            random.shuffle(self.class_range)
            perm = list(self.class_range)
            if order and len(perm) > 1 and perm[0] == order[-1]:
                # Don't start the new permutation with the class the previous
                # one ended on.
                perm[0], perm[-1] = perm[-1], perm[0]
            order.extend(perm)
        return order[:self.n_batches]

    def __iter__(self):
        for batch_class in self._class_order():
            members = self.per_class_sample_indices[batch_class]
            if self.batch_size <= len(members):
                # replace=False: a batch never contains the same sample twice.
                yield np.random.choice(members, self.batch_size, replace=False).tolist()
            else:
                yield list(members)

    def __len__(self):
        return self.n_batches


def main():
    """Build (or load from a cache file) the per-class index lists, then run
    one pass of single-class batches and report how many classes it visited.

    Relies on the module-level globals ``train_dataset``, ``num_classes`` and
    ``args`` set in the ``__main__`` block.

    Raises:
        ValueError: if a cached index file does not match ``train_dataset``.
    """
    file_path = 'a_file_path'
    file_name = 'per_class_sample_indices.pt'
    cache_path = os.path.join(file_path, file_name)
    if not os.path.exists(cache_path):
        print('File: {} does not exists. Create it.'.format(file_name))
        per_class_sample_indices = DS(train_dataset, num_classes).per_class_sample_indices()
        # torch.save does not create missing parent directories.
        if file_path:
            os.makedirs(file_path, exist_ok=True)
        torch.save(per_class_sample_indices, cache_path)
    else:
        per_class_sample_indices = torch.load(cache_path)
        print('File: {} exists. Do not create it.'.format(file_name))
        # The cache file name does not depend on the dataset, so one written
        # for --dummy data would otherwise be reused for ImageNet (or vice
        # versa) and silently yield wrong batches.
        n_cached = sum(len(members) for members in per_class_sample_indices)
        if len(per_class_sample_indices) != num_classes or n_cached != len(train_dataset):
            raise ValueError(
                'Cached {} ({} classes, {} samples) does not match the training set '
                '({} classes, {} samples); delete it and rerun.'.format(
                    cache_path, len(per_class_sample_indices), n_cached,
                    num_classes, len(train_dataset)))

    batch_sampler = BatchSampler(per_class_sample_indices,
                                 batch_size=args.batch_size)

    # batch_sampler is mutually exclusive with batch_size, shuffle, sampler
    # and drop_last, so those arguments are not passed.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=args.workers,
        pin_memory=True,
        batch_sampler=batch_sampler,
    )

    labels = []
    for _inputs, _labels in train_loader:
        # .item() raises if the batch holds more than one distinct label, so
        # this also checks that every batch is single-class.
        batch_label = torch.unique(_labels).item()
        labels.append(batch_label)
        print("Unique labels: {}".format(batch_label))

    labels = set(labels)
    print('Length of traversed unique labels: {}'.format(len(labels)))


if __name__ == "__main__":
    # Parse options; main() reads the result through the global ``args``.
    parser = argparse.ArgumentParser(description='PyTorch ImageNet Training')

    parser.add_argument('--data', metavar='DIR', nargs='?', default='imagenet',
                        help='path to dataset (default: imagenet)')
    parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")

    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N',
                        help='mini-batch size (default: 64), this is the total '
                             'batch size of all GPUs on the current node when '
                             'using Data Parallel or Distributed Data Parallel')

    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')

    args = parser.parse_args()

    # Build the globals train_dataset, val_dataset and num_classes used by main().
    if args.dummy:
        print("=> Dummy data is used!")
        num_classes = 100
        train_dataset = datasets.FakeData(size=12811, image_size=(3, 224, 224),
                                          num_classes=num_classes, transform=transforms.ToTensor())
        val_dataset = datasets.FakeData(5000, (3, 224, 224), num_classes, transforms.ToTensor())
    else:
        traindir = os.path.join(args.data, 'train')
        valdir = os.path.join(args.data, 'val')

        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        val_dataset = datasets.ImageFolder(
            valdir,
            transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
        num_classes = len(train_dataset.classes)

    main()

A similar question for TensorFlow can be found here.

Upvotes: 5

Views: 1917

Answers (2)

Ivan

Reputation: 40618

Your code seems fine. The issue here is not the sampler but the preprocessing step you must perform to group the instance indices by class. Since this grouping is always the same, I recommend caching this information (the data contained in `self.indices`) on your file system so that you avoid rebuilding it on every dataset load. You can do so using either `numpy.save` or `torch.save`.

Upvotes: 3

Shai

Reputation: 114786

You should write your own batch_sampler class for the DataLoader.

Upvotes: 2

Related Questions