Reputation: 59
I'm trying to create a custom DataLoader that yields batches in a specific format for use with PyTorch. Does anyone have an idea how I can do it? I have followed the PyTorch tutorial, but I couldn't find my answer there!
I need a DataLoader that yields tuples of the following format: (Bx3xHxW FloatTensor x, BxHxW LongTensor y, BxN LongTensor y_cls), where x is a batch of input images, y is a batch of ground-truth segmentation maps, and y_cls is a batch of 1D tensors of dimensionality N (the total number of classes), with y_cls[i, T] = 1 if class T is present in image i and 0 otherwise.
I hope someone can help me solve this problem. :) Thanks!
Upvotes: 1
Views: 1606
Reputation: 114926
You simply need a dataset class derived from torch.utils.data.Dataset whose __getitem__(index) returns a tuple (x, y, y_cls) of the types you want; PyTorch will take care of everything else, including stacking the samples into batches.
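import torch
from torch.utils import data

H, W, N = 256, 256, 21  # placeholder sizes: set these for your data

class MyTupleDataset(data.Dataset):
    def __init__(self):
        super(MyTupleDataset, self).__init__()
        # init your dataset here...

    def __len__(self):
        return 100  # placeholder: number of samples in your dataset

    def __getitem__(self, index):
        # placeholder tensors: load the real sample for `index` here
        x = torch.zeros(3, H, W)  # batch dim is handled by the data loader
        y = torch.zeros(H, W, dtype=torch.long)
        y_cls = torch.zeros(N, dtype=torch.long)
        return x, y, y_cls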
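If y_cls is not stored with your data, one way to fill it is to derive it from the segmentation map y inside __getitem__. This is only a sketch: the helper name classes_present is made up for illustration, and it assumes every value in y is a valid class id in the range [0, N), so an "ignore" label such as 255 would have to be filtered out first.

```python
import torch


def classes_present(y: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: multi-hot LongTensor of shape (num_classes,)."""
    y_cls = torch.zeros(num_classes, dtype=torch.long)
    # torch.unique returns the distinct class ids that occur in the map;
    # assumes all of them lie in [0, num_classes)
    y_cls[torch.unique(y)] = 1
    return y_cls


# toy 2x3 segmentation map containing classes 0, 2 and 4
y = torch.tensor([[0, 0, 2], [2, 2, 4]])
y_cls = classes_present(y, num_classes=5)
print(y_cls)  # tensor([1, 0, 1, 0, 1])
```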
That's it. Pass an instance of MyTupleDataset to PyTorch's torch.utils.data.DataLoader and you are done.
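To check that the batches come out in the requested format, something like the following should work. It is a self-contained sketch with random stand-in data: RandomTupleDataset, the sizes H, W, N, the dataset length and the batch size are all made up for illustration and are not part of the question.

```python
import torch
from torch.utils import data

H, W, N = 8, 8, 5  # hypothetical image size and number of classes


class RandomTupleDataset(data.Dataset):
    """Stand-in dataset that returns random samples of the right types."""

    def __init__(self, length=10):
        super().__init__()
        self.length = length

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        x = torch.rand(3, H, W)           # float32 image, no batch dim
        y = torch.randint(0, N, (H, W))   # int64 class ids in [0, N)
        y_cls = torch.zeros(N, dtype=torch.long)
        y_cls[torch.unique(y)] = 1        # mark the classes present in y
        return x, y, y_cls


loader = data.DataLoader(RandomTupleDataset(), batch_size=4, shuffle=True)

# The default collate function stacks each element of the tuple along a
# new leading batch dimension.
x, y, y_cls = next(iter(loader))
print(x.shape, y.shape, y_cls.shape)
# torch.Size([4, 3, 8, 8]) torch.Size([4, 8, 8]) torch.Size([4, 5])
```

Note that the last batch can be smaller than batch_size when the dataset length is not a multiple of it; passing drop_last=True to the DataLoader skips that incomplete batch.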
Upvotes: 1