MUAS

Reputation: 626

Adding class objects to Pytorch Dataloader: batch must contain tensors

I have a custom Pytorch dataset that returns a dictionary containing a class object "queries".

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return self.values.shape[0]

    def __getitem__(self, idx):
        sample = DeviceDict({'query': self.queries[idx],
                             "values": self.values[idx],
                             "targets": self.targets[idx]})
        return sample

The problem is that when I put the queries in a data loader, I get "default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'query.Query'>". Is there a way to have a class object in my data loader? It blows up at next(iterator) in the code below.

train_queries = QueryDataset(queries, values, targets)
train_loader = torch.utils.data.DataLoader(train_queries,
                                           batch_size=10,
                                           shuffle=True,
                                           drop_last=False)
for epoch in range(epochs):
    iterator = iter(train_loader)
    for _ in range(len(train_loader)):
        batch = next(iterator)
        out = model(batch)
        loss = criterion(out["pred"], batch["targets"])
        optimizer.zero_grad()
        loss.sum().backward()
        optimizer.step()

Upvotes: 4

Views: 8217

Answers (2)

MUAS

Reputation: 626

For those curious, this is the DeviceDict and custom collate function that I used to get things to work.

class DeviceDict(dict):

    def __init__(self, *args):
        super(DeviceDict, self).__init__(*args)

    def to(self, device):
        dd = DeviceDict()
        for k, v in self.items():
            if torch.is_tensor(v):
                dd[k] = v.to(device)
            else:
                dd[k] = v
        return dd


def collate_helper(elems, key):
    # keep the Query objects as a plain Python list; batch everything else normally
    if key == "query":
        return elems
    else:
        return torch.utils.data.dataloader.default_collate(elems)


def custom_collate(batch):
    elem = batch[0]
    return DeviceDict({key: collate_helper([d[key] for d in batch], key) for key in elem})
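
To see it end to end, here is a minimal, self-contained sketch of wiring this collate function into a DataLoader. The Query stand-in class and the sample shapes are made up for illustration; only the DeviceDict/collate machinery comes from the answer above:

```python
import torch


class Query:
    # stand-in for the real query.Query class (assumption for this sketch)
    def __init__(self, qid):
        self.qid = qid


class DeviceDict(dict):
    def to(self, device):
        # move tensor values to the device, pass everything else through
        return DeviceDict({k: v.to(device) if torch.is_tensor(v) else v
                           for k, v in self.items()})


def collate_helper(elems, key):
    if key == "query":
        return elems  # keep Query objects as a plain Python list
    return torch.utils.data.dataloader.default_collate(elems)


def custom_collate(batch):
    return DeviceDict({key: collate_helper([d[key] for d in batch], key)
                       for key in batch[0]})


# a list of dicts works as a map-style dataset
samples = [{"query": Query(i),
            "values": torch.randn(3),
            "targets": torch.tensor(float(i))} for i in range(10)]
loader = torch.utils.data.DataLoader(samples, batch_size=5,
                                     collate_fn=custom_collate)
batch = next(iter(loader))
# batch["query"] is a list of 5 Query objects,
# while batch["values"] is batched into a 5x3 tensor as usual
```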

Upvotes: 4

Hossein

Reputation: 25924

You need to define your own collate_fn in order to do this. A sloppy approach, just to show you how things work here, would be something like this:

import torch
class DeviceDict:
    def __init__(self, data):
        self.data = data 

    def print_data(self):
        print(self.data)

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        return 5

    def __getitem__(self, idx):
        sample = {'query': self.queries[idx],
                 "values": self.values[idx],
                 "targets": self.targets[idx]}
        return sample

def custom_collate(batch):
    return DeviceDict(batch)

dt = QueryDataset("q", "v", "t")
dl = torch.utils.data.DataLoader(dt, batch_size=1, collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()

Basically, collate_fn allows you to achieve custom batching or add support for custom data types, as explained in the link I previously provided.
As you can see, this just shows the concept; you need to adapt it to your own needs.
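
To make the mechanism concrete, here is a small illustration (not from the answer, just a sketch) of what a collate_fn actually receives: the raw Python list of the individual samples in a batch, which you may return or transform however you like:

```python
import torch


def identity_collate(batch):
    # collate_fn is handed the list of samples for one batch;
    # whatever it returns becomes the batch object, no stacking required
    assert isinstance(batch, list)
    return batch


data = [{"x": i} for i in range(4)]
loader = torch.utils.data.DataLoader(data, batch_size=2,
                                     collate_fn=identity_collate)
first = next(iter(loader))
# first is [{"x": 0}, {"x": 1}] since shuffle defaults to False
```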

Upvotes: 6
