Ashar

Reputation: 794

Python: Generate a unique batch from given dataset

I'm applying a CNN to classify a given dataset.

My function:

import cv2

def batch_generator(dataset, input_shape=(256, 256), batch_size=32):
    # `labels` is a dict defined elsewhere in my script
    # note: batch_size is not used yet, so this loads the whole dataset at once
    dataset_images = []
    dataset_labels = []
    for path in dataset:
        image = cv2.imread(path, cv2.IMREAD_COLOR)
        dataset_images.append(cv2.resize(image, input_shape, interpolation=cv2.INTER_AREA))
        dataset_labels.append(labels[path.split('/')[-2]])
    return dataset_images, dataset_labels

This function is supposed to be called once per epoch and should return a unique batch of size 'batch_size' containing dataset_images (each image is 256x256) and the corresponding labels from the labels dictionary.

The input 'dataset' contains the paths to all the images, so I'm opening them and resizing them to 256x256. Can someone help me extend this code so that it returns the desired batches?

Upvotes: 3

Views: 1906

Answers (2)

jodag

Reputation: 22294

PyTorch has two similar-sounding but very different abstractions for loading data. I strongly recommend reading the PyTorch documentation on data loading (torch.utils.data). To summarize:

  1. A Dataset is an object you generally implement yourself; it returns an individual sample (data + label).
  2. A DataLoader is a built-in PyTorch class that draws batches of samples from a Dataset (potentially in parallel).
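To make the split concrete, here is a rough plain-Python picture of what a DataLoader with shuffle=True does each epoch: draw a fresh random permutation of the indices, then cut it into non-overlapping batches, so every sample appears exactly once per epoch. This is only an illustration under that assumption, not PyTorch's actual implementation, and `iterate_batches` is a hypothetical name.

```python
import random
from typing import Any, Iterator, List, Optional

def iterate_batches(dataset: Any, batch_size: int, shuffle: bool = True,
                    seed: Optional[int] = None) -> Iterator[List[Any]]:
    """Yield one epoch of batches; every index lands in exactly one batch.

    `dataset` can be anything with __getitem__ and __len__ (even a list).
    """
    indices = list(range(len(dataset)))
    if shuffle:
        # seed=None draws a new permutation on every call, i.e. every epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # the final batch is smaller when len(dataset) % batch_size != 0
        yield [dataset[i] for i in indices[start:start + batch_size]]
```

For example, `list(iterate_batches(list(range(10)), batch_size=4))` gives three batches of sizes 4, 4 and 2 that together contain each of 0..9 once. A real DataLoader additionally collates each batch into tensors and can load samples in worker processes.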

A (map-style) Dataset is a simple object that implements just two mandatory methods: __getitem__ and __len__. __getitem__ is the method invoked when you use the square-bracket operator on the object, i.e. dataset[i], and __len__ is the method invoked when you call Python's built-in len function on it, i.e. len(dataset).
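The protocol is small enough to show without PyTorch at all. A minimal sketch with a hypothetical SquaresDataset whose i-th sample is the pair (i, i*i):

```python
class SquaresDataset:
    """Toy map-style dataset (hypothetical): sample i is the tuple (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __getitem__(self, index):
        # called for ds[index]; raising IndexError also lets `for x in ds` stop
        if not 0 <= index < self.n:
            raise IndexError(index)
        return index, index * index

    def __len__(self):
        # called for len(ds)
        return self.n

ds = SquaresDataset(5)
print(len(ds))  # 5
print(ds[3])    # (3, 9)
```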

In PyTorch you usually want __getitem__ to return a tuple containing both the data and the label for a single item in your dataset. For example, based on what you provided, something like this should suit your needs:

import cv2
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.functional as F

class CustomDataset(Dataset):
    def __init__(self, image_paths, labels, input_shape=(256, 256)):
        # `image_paths` is what you called `dataset` in your example.
        #               I assume this is a list of image paths.
        # `labels` isn't defined in your script, but I assume it's a
        #          dict that maps each image's parent directory name to
        #          an integer label between 0 and num_classes - 1
        self.image_paths = image_paths
        self.labels = labels
        self.input_shape = input_shape  # passed to cv2.resize as (width, height)

    def __getitem__(self, index):
        # return the data and label for the specified index
        image_path = self.image_paths[index]
        data = cv2.resize(cv2.imread(image_path, cv2.IMREAD_COLOR),
                          self.input_shape, interpolation=cv2.INTER_AREA)
        # note: cv2.imread returns channels in BGR order; apply
        # cv2.cvtColor(data, cv2.COLOR_BGR2RGB) if your model expects RGB
        label = self.labels[image_path.split('/')[-2]]

        # convert data to PyTorch tensor
        # This converts data from a uint8 np.array of shape HxWxC
        # between 0 and 255 to a PyTorch float32 tensor of shape CxHxW
        # between 0.0 and 1.0.
        data = F.to_tensor(data)

        return data, label

    def __len__(self):
        return len(self.image_paths)

...
# using what you call "dataset" and "labels"
# num_workers > 0 loads data in parallel worker processes while the network is running
dataloader = DataLoader(
    CustomDataset(dataset, labels, (256, 256)),
    batch_size=32,
    shuffle=True,    # shuffle tells the loader to randomly sample the
                     # dataset without replacement, in a new order every epoch
    num_workers=4    # number of worker processes that load from the
                     # dataset in parallel while your model is
                     # processing stuff
)

# training loop (num_epochs is defined elsewhere)
for epoch in range(num_epochs):
    # iterates over all data in your dataset in a random order,
    # in batches of size 32, each time this loop is run
    for data_batch, label_batch in dataloader:
        # data_batch is a PyTorch FloatTensor of shape 32x3x256x256
        # label_batch is a PyTorch LongTensor of shape 32
        # (the last batch is smaller if len(dataset) isn't a multiple of 32,
        #  unless you pass drop_last=True to the DataLoader)

        # if using GPU acceleration, now is the time to move data_batch and label_batch to the GPU
        # data_batch = data_batch.cuda()
        # label_batch = label_batch.cuda()

        # zero the gradients, pass data through your model, backprop, and step the optimizer
        ...
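To check the batch shapes described in the comments above without any image files, a sketch can swap CustomDataset for a hypothetical RandomImageDataset that returns random tensors and integer labels. The DataLoader's default collation stacks the per-sample tensors along a new first axis and turns the Python int labels into an int64 (Long) tensor.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class RandomImageDataset(Dataset):
    """Hypothetical stand-in for CustomDataset: random tensors instead of image files."""

    def __init__(self, n, num_classes=10):
        self.n = n
        self.num_classes = num_classes

    def __getitem__(self, index):
        data = torch.rand(3, 256, 256)    # float32 CxHxW in [0, 1), like F.to_tensor output
        label = index % self.num_classes  # a plain Python int, like labels[...]
        return data, label

    def __len__(self):
        return self.n

# 70 samples -> two full batches of 32 and one final batch of 6
loader = DataLoader(RandomImageDataset(70), batch_size=32, shuffle=True)
for data_batch, label_batch in loader:
    print(tuple(data_batch.shape), label_batch.dtype)
# (32, 3, 256, 256) torch.int64
# (32, 3, 256, 256) torch.int64
# (6, 3, 256, 256) torch.int64
```

Passing drop_last=True to the DataLoader would discard the final batch of 6.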

Upvotes: 0

CutePoison

Reputation: 5385

As @jodag suggests, using DataLoaders is a good idea.

Here is a snippet that I use for some of my CNNs in PyTorch:

from torch.utils.data import Dataset, DataLoader
import torch

class Data(Dataset):
    """
    Constructs a Dataset to be passed to a DataLoader
    """
    def __init__(self, X, y):
        X = torch.from_numpy(X).float()

        # Transpose to fit the dimensions of my network.
        # This swaps axes 1 and 2, so X must have at least 3 dimensions;
        # drop this line for a plain n x k feature matrix.
        X = torch.transpose(X, 1, 2)

        # use .long() instead if training with nn.CrossEntropyLoss (class-index targets)
        y = torch.from_numpy(y).float()
        self.X, self.y = X, y

    def __getitem__(self, i):
        return self.X[i], self.y[i]

    def __len__(self):
        return self.X.shape[0]

def create_data_loader(X, y, batch_size, **kwargs):
    """
    Creates a data-loader for the data X and y

    params:
    -------

    X: np.array
        - numpy array with "n" samples along axis 0; because Data swaps
          axes 1 and 2, it must be at least 3-D (e.g. "n" x k x m)

    y: np.array
        - numpy array of size "n"

    batch_size: int
        - number of samples per batch

    kwargs:
        - Additional keyword arguments for "DataLoader" (e.g. shuffle=True)

    return
    ------

    dl: torch.utils.data.DataLoader object
    """

    data = Data(X, y)

    # default to loading in the main process, but let callers override it
    kwargs.setdefault("num_workers", 0)
    dl = DataLoader(data, batch_size=batch_size, **kwargs)
    return dl
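When the data is already a plain 2-D feature matrix that needs no transpose, torch.utils.data.TensorDataset gives the same indexing behaviour as the Data class without a custom subclass. A sketch with made-up random data (100 samples, 20 features, 3 hypothetical classes):

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

X = np.random.rand(100, 20).astype(np.float32)  # made-up features: 100 samples x 20
y = np.random.randint(0, 3, size=100)           # made-up class indices in {0, 1, 2}

# TensorDataset indexes all tensors along their first axis: dataset[i] -> (X[i], y[i])
dataset = TensorDataset(torch.from_numpy(X), torch.from_numpy(y).long())
loader = DataLoader(dataset, batch_size=32, shuffle=True)

xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([32, 20]) torch.Size([32])
print(len(loader))         # 4, i.e. ceil(100 / 32) batches per epoch
```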

which is used like this:

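import torch
from create_data_loader import create_data_loader  # assumes the snippet above is saved as create_data_loader.py

# Note: DataLoader does NOT shuffle by default (shuffle=False), so request it explicitly for training
train_data_loader = create_data_loader(X_train, y_train, batch_size=32, shuffle=True)
# shuffle=False keeps the samples in their original order, e.g. for cross-validation
val_data_loader = create_data_loader(X_val, y_val, batch_size=32, shuffle=False)

# net and num_epochs are defined elsewhere
for epoch in range(num_epochs):
    net.train()
    for x_batch, y_batch in train_data_loader:
        logit = net(x_batch)
        ...  # compute the loss, backprop, and step the optimizer

    net.eval()
    with torch.no_grad():
        for x_batch, y_batch in val_data_loader:
            logit = net(x_batch)
            classes_pred = logit.argmax(dim=1)
            # a bool tensor has no .mean(), so convert to float first
            print(f"Val accuracy: {(y_batch == classes_pred).float().mean().item()}")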
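Printing the accuracy of each validation batch separately is noisy. A sketch of a helper that accumulates correct predictions over the whole validation set instead; `evaluate_accuracy` is a hypothetical name, and it assumes `net` maps a batch of inputs to logits of shape batch x num_classes and that the labels are class indices.

```python
import torch

def evaluate_accuracy(net, data_loader, device="cpu"):
    """Fraction of correctly classified samples over an entire DataLoader."""
    was_training = net.training
    net.eval()                  # disable dropout, use running batch-norm statistics
    correct, total = 0, 0
    with torch.no_grad():       # no gradients are needed for evaluation
        for x_batch, y_batch in data_loader:
            x_batch, y_batch = x_batch.to(device), y_batch.to(device)
            preds = net(x_batch).argmax(dim=1)
            correct += (preds == y_batch).sum().item()
            total += y_batch.numel()
    net.train(was_training)     # restore whichever mode the net was in before
    return correct / total
```

In the loop above it would replace the inner validation loop, e.g. `val_acc = evaluate_accuracy(net, val_data_loader)` once per epoch.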

Upvotes: 1
