Panertoĸ

Reputation: 33

How can I create images for each batch using PyTorch?

I want to make a binary classifier that classifies the following:

Class 1. Some images that I already have.

Class 2. Some images that I create from a function, using the images of class 1.

To speed up the process, instead of pre-generating both classes and then loading them, I would like the class 2 images to be created on the fly for each batch.

Any ideas on how I can tackle this? If I use the DataLoader as usual, I have to pass it the images of both classes up front, but since the class 2 images don't exist yet, I don't know how to do that.

Thanks.

Upvotes: 0


Answers (1)

aretor

Reputation: 2569

You can tackle the problem in at least two ways.

  1. (Preferred) Create a custom Dataset class, AugDset, such that AugDset.__len__() returns 2 * len(real_dset), and, when idx >= len(real_dset), AugDset.__getitem__(idx) generates the synthetic image from real_dset[idx - len(real_dset)].
  2. Write a custom collate_fn and pass it to the DataLoader; given a batch of real images, it augments the batch with the synthetic images generated from them.
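A minimal sketch of option 1, assuming the real dataset returns tuples whose first element is the image tensor (as TensorDataset does). make_synthetic is a hypothetical placeholder for your own generating function; here it just flips the image horizontally. Real images get label 0 and generated ones label 1.

```python
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset


def make_synthetic(img: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the class 2 generator: horizontal flip.
    return torch.flip(img, dims=[-1])


class AugDset(Dataset):
    """Indices [0, n) are real images (label 0); [n, 2n) are generated (label 1)."""

    def __init__(self, real_dset: Dataset, transform=make_synthetic):
        self.real_dset = real_dset
        self.transform = transform

    def __len__(self) -> int:
        # Each real image also yields exactly one synthetic image.
        return 2 * len(self.real_dset)

    def __getitem__(self, idx: int):
        n = len(self.real_dset)
        if idx < n:
            # Assumes the wrapped dataset returns (img, ...) tuples.
            return self.real_dset[idx][0], 0
        # Generate the synthetic counterpart of real image idx - n on the fly.
        img = self.real_dset[idx - n][0]
        return self.transform(img), 1


if __name__ == "__main__":
    real = TensorDataset(torch.rand(10, 3, 32, 32))
    # shuffle=True so each batch mixes real and generated images.
    loader = DataLoader(AugDset(real), batch_size=4, shuffle=True)
    imgs, labels = next(iter(loader))
    print(imgs.shape, labels)
```

Because the generation happens inside __getitem__, nothing is precomputed, and with num_workers > 0 the images are generated in the DataLoader's worker processes. If you use BCEWithLogitsLoss, convert the labels with labels.float().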
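A minimal sketch of option 2, again assuming each sample is a tuple whose first element is the image tensor and using a hypothetical make_synthetic (a horizontal flip) in place of your real generator. The collate function stacks the real images, generates one synthetic image per real one, and returns both together with 0/1 float labels.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_synthetic(img: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the class 2 generator: horizontal flip.
    return torch.flip(img, dims=[-1])


def augmenting_collate(batch):
    # batch is a list of samples from the Dataset; each sample is assumed
    # to be a tuple whose first element is the image tensor.
    real = torch.stack([sample[0] for sample in batch])
    fake = torch.stack([make_synthetic(sample[0]) for sample in batch])
    imgs = torch.cat([real, fake], dim=0)
    # Float labels (0 = real, 1 = synthetic), e.g. for BCEWithLogitsLoss.
    labels = torch.cat([torch.zeros(len(real)), torch.ones(len(fake))])
    return imgs, labels


if __name__ == "__main__":
    real_dset = TensorDataset(torch.rand(10, 3, 32, 32))
    loader = DataLoader(real_dset, batch_size=4, shuffle=True,
                        collate_fn=augmenting_collate)
    imgs, labels = next(iter(loader))
    print(imgs.shape, labels)
```

Note that with this approach each returned batch is twice the batch_size passed to the DataLoader, and the real images always come before the synthetic ones within a batch.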

Upvotes: 1
