Pythmania

Reputation: 59

Make a custom DataLoader with PyTorch

I'm trying to create a custom DataLoader that yields a specific format so I can use it with the PyTorch library. Does anyone have an idea how I can do it? I have followed the PyTorch tutorial but didn't find my answer!

I need a DataLoader that yields tuples of the following format: (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls), where:
- x is a batch of input images,
- y is a batch of ground-truth segmentation maps,
- y_cls is a batch of 1D tensors of dimensionality N (N = total number of classes), with y_cls[i, T] = 1 if class T is present in image i and 0 otherwise.
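To illustrate the y_cls definition above, here is a minimal sketch that derives one image's class-presence vector from its ground-truth segmentation map. It assumes every label in the map is a valid class index in [0, N) (i.e. no "ignore" label such as 255); the function name class_presence is hypothetical.

```python
import torch

def class_presence(seg_map: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Return a length-N LongTensor with 1 at index T if class T occurs in seg_map.

    Assumption: seg_map is an HxW LongTensor whose values all lie in [0, num_classes).
    """
    y_cls = torch.zeros(num_classes, dtype=torch.long)
    present = torch.unique(seg_map)  # distinct class labels that appear in the map
    y_cls[present] = 1               # mark each present class with 1
    return y_cls

y = torch.tensor([[0, 0, 2],
                  [2, 3, 3]])
print(class_presence(y, num_classes=5))  # tensor([1, 0, 1, 1, 0])
```

Computing this per sample (inside __getitem__) is enough: once the DataLoader stacks B such vectors, you get the BxN y_cls batch described above.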

I hope someone can help me solve this problem :) Thanks!

Upvotes: 1

Views: 1606

Answers (1)

Shai

Reputation: 114926

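You simply need a dataset class derived from torch.utils.data.Dataset that implements __len__() and __getitem__(index), where __getitem__ returns a single sample as a tuple (x, y, y_cls) of the types you want, without the batch dimension. PyTorch's DataLoader takes care of everything else, including stacking the samples into batches.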

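import torch
from torch.utils import data

class MyTupleDataset(data.Dataset):
  def __init__(self, num_samples, H, W, N):
    super(MyTupleDataset, self).__init__()
    # init your dataset here (e.g. load lists of image/mask files)...
    self.num_samples = num_samples
    self.H, self.W, self.N = H, W, N

  def __len__(self):
    # number of samples; the DataLoader needs this
    return self.num_samples

  def __getitem__(self, index):
    # return ONE sample; the batch dim is added by the data loader
    # (placeholder zeros: replace with your real image/label loading)
    x = torch.zeros(3, self.H, self.W)                 # FloatTensor 3xHxW
    y = torch.zeros(self.H, self.W, dtype=torch.long)  # LongTensor HxW
    y_cls = torch.zeros(self.N, dtype=torch.long)      # LongTensor N
    return x, y, y_cls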

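That's it. Pass an instance of MyTupleDataset to torch.utils.data.DataLoader and you are done: its default collate function stacks the per-sample tensors, so each batch contains x of shape Bx3xHxW, y of shape BxHxW and y_cls of shape BxN.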
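For example, a minimal end-to-end sketch, assuming the images and masks already fit in memory as tensors. InMemorySegDataset and the random stand-in data are hypothetical, for illustration only. It computes y_cls from y inside __getitem__ and relies on the DataLoader's default collate function to build the batches:

```python
import torch
from torch.utils.data import Dataset, DataLoader

class InMemorySegDataset(Dataset):
    """Hypothetical example dataset: images and masks are preloaded tensors."""

    def __init__(self, images, masks, num_classes):
        assert len(images) == len(masks)
        self.images = images            # float tensor, shape (M, 3, H, W)
        self.masks = masks              # long tensor, shape (M, H, W)
        self.num_classes = num_classes

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        x = self.images[index]          # (3, H, W) float
        y = self.masks[index]           # (H, W) long
        y_cls = torch.zeros(self.num_classes, dtype=torch.long)
        y_cls[torch.unique(y)] = 1      # 1 for every class present in this mask
        return x, y, y_cls

# Random stand-in data: 10 images of size 32x48 with N = 4 classes
M, H, W, N = 10, 32, 48, 4
images = torch.rand(M, 3, H, W)
masks = torch.randint(0, N, (M, H, W))
loader = DataLoader(InMemorySegDataset(images, masks, N), batch_size=4, shuffle=True)

x, y, y_cls = next(iter(loader))
print(x.shape, x.dtype)          # torch.Size([4, 3, 32, 48]) torch.float32
print(y.shape, y.dtype)          # torch.Size([4, 32, 48]) torch.int64
print(y_cls.shape, y_cls.dtype)  # torch.Size([4, 4]) torch.int64
```

With tuple samples, the default collate function returns each batch as a list of three stacked tensors, so unpacking it into x, y, y_cls works directly in a training loop. The last batch here holds only 2 samples (10 is not a multiple of 4) unless you pass drop_last=True.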

Upvotes: 1
