Ichi Ban

Reputation: 53

PyTorch - Custom Dataset out of Range

I'm new to PyTorch and deep learning in general, so I hope this is the right place to ask questions like this.

I wanted to create my first Dataset, but iterating over it always goes out of bounds. The problem is easiest to show with code and output.

import torch


class DataframeDataset(torch.utils.data.Dataset):
    """Load a PyTorch Dataset from a pandas DataFrame."""

    def __init__(self, data_frame, input_key, target_key, transform=None, features=None):
        self.data_frame = data_frame
        self.input_key = input_key
        self.target_key = target_key
        self.inputs = self.data_frame[input_key]
        self.targets = self.data_frame[target_key]
        self.transform = transform
        self.features = [input_key, target_key] if features is None else features
        self.len = len(self.inputs)

    def __len__(self):
        return self.len

    def __str__(self):
        return str(self.info())

    def info(self):
        info = {
            'features': self.features,
            'num_rows': len(self)
        }
        return "Dataset("+ str(info) + ")"

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()

        data = {
            self.input_key: self.inputs[idx],
            self.target_key: self.targets[idx]
        }

        if self.transform:
            return self.transform(data)

        return data

    @staticmethod
    def collate_fn(input_key, output_key):
        def collate(batch):
            speeches = [data[input_key] for data in batch]
            sentences = [data[output_key] for data in batch]
            return speeches, sentences

        return collate
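The collate_fn static method is a small factory. It captures the two column keys and returns a function that turns a list of sample dicts (as produced by __getitem__) into two parallel lists. The sketch below is a standalone copy of the same logic under the hypothetical name make_collate_fn, and it shows what one batch looks like:

```python
def make_collate_fn(input_key, output_key):
    """Hypothetical standalone copy of DataframeDataset.collate_fn."""
    def collate(batch):
        # batch is a list of dicts, one per sample returned by __getitem__
        inputs = [sample[input_key] for sample in batch]
        targets = [sample[output_key] for sample in batch]
        return inputs, targets

    return collate


collate = make_collate_fn("input", "target")
batch = [{"input": "x1", "target": "y2"}, {"input": "x1", "target": "y2"}]
inputs, targets = collate(batch)
print(inputs, targets)  # ['x1', 'x1'] ['y2', 'y2']
```

A torch.utils.data.DataLoader would apply the same transformation to each batch if DataframeDataset.collate_fn("input", "target") were passed as its collate_fn argument.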

With some mock data:

    import pandas as pd

    data = [("x1", "y2", "A3"), ("x1", "y2", "b3"), ("x1", "y2", "c3"), ("x1", "y2", "d3")]
    df = pd.DataFrame(data, columns=['input', 'target', 'random'])
    print(df.head())
  input target random
0    x1     y2     A3
1    x1     y2     b3
2    x1     y2     c3
3    x1     y2     d3
    ds = DataframeDataset(data_frame=df, input_key="input", target_key="target", transform=None)
    print("Len:", len(ds))
    print("Ds", ds)
    print(ds[0])
Len: 4
Ds Dataset({'features': ['input', 'target'], 'num_rows': 4})
{'input': 'x1', 'target': 'y2'}

So the basic functions seem to work. However, when I iterate over the dataset with a for-each loop, the loop does not know where the data ends. It keeps requesting indices past the last row, and I get a KeyError.

    for idx, data in enumerate(ds):
        print(idx,"->",data)
0 -> {'input': 'x1', 'target': 'y2'}
1 -> {'input': 'x1', 'target': 'y2'}
2 -> {'input': 'x1', 'target': 'y2'}
3 -> {'input': 'x1', 'target': 'y2'}
Traceback (most recent call last):
  File "/home/warmachine/.local/lib/python3.8/site-packages/pandas/core/indexes/range.py", line 351, in get_loc
    return self._range.index(new_key)
ValueError: 4 is not in range
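The loop does not stop because torch.utils.data.Dataset defines no __iter__, so Python falls back to its legacy sequence protocol. It calls __getitem__(0), __getitem__(1), and so on, and treats only IndexError as the end of the sequence. Any other exception propagates out of the loop, including the KeyError that pandas raises for a missing label. The sketch below reproduces this in plain Python, without PyTorch or pandas. KeyErrorSeq and IndexErrorSeq are hypothetical stand-ins for the dataset:

```python
class KeyErrorSeq:
    """Hypothetical stand-in: raises KeyError past the end, like series[idx]."""

    def __init__(self, items):
        self.items = items

    def __getitem__(self, idx):
        if idx >= len(self.items):
            raise KeyError(idx)
        return self.items[idx]


class IndexErrorSeq(KeyErrorSeq):
    """Hypothetical stand-in: raises IndexError past the end, like series.iloc[idx]."""

    def __getitem__(self, idx):
        if idx >= len(self.items):
            raise IndexError(idx)
        return self.items[idx]


# No __iter__ is defined, so iteration calls __getitem__(0), (1), ...
# IndexError is treated as "end of sequence" and the loop stops cleanly.
print(list(IndexErrorSeq(["a", "b"])))  # ['a', 'b']

# KeyError is not treated as the end, so it escapes after the last valid item.
try:
    for item in KeyErrorSeq(["a", "b"]):
        print(item)
except KeyError as err:
    print("KeyError:", err)
```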

If I do something like

    for idx in range(0, len(ds)):
        data = ds[idx]
        print(idx, "->", data)

it works, but I need the for-each style to work so that I can use this Dataset with the Hugging Face Trainer.
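The Hugging Face Trainer probably does not need direct iteration over the dataset. For a map-style dataset it wraps the data in a torch.utils.data.DataLoader. The DataLoader's sampler only produces indices in range(len(dataset)), so __getitem__ is never asked for index 4. The sketch below shows this with a plain DataLoader. It assumes a trimmed-down copy of the question's class that keeps the same label-based lookup; TinyDataframeDataset and collate are hypothetical names:

```python
import pandas as pd
from torch.utils.data import DataLoader, Dataset


class TinyDataframeDataset(Dataset):
    """Hypothetical trimmed-down version of the question's DataframeDataset."""

    def __init__(self, df, input_key, target_key):
        self.input_key = input_key
        self.target_key = target_key
        self.inputs = df[input_key]
        self.targets = df[target_key]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        # Same label-based lookup as the question: idx >= len(self) would raise KeyError.
        return {self.input_key: self.inputs[idx], self.target_key: self.targets[idx]}


def collate(batch):
    # Same idea as the question's collate_fn: split a list of dicts into two lists.
    return [s["input"] for s in batch], [s["target"] for s in batch]


df = pd.DataFrame([("x1", "y2")] * 4, columns=["input", "target"])
loader = DataLoader(TinyDataframeDataset(df, "input", "target"), batch_size=2, collate_fn=collate)

# The default (sequential) sampler yields indices 0..len(dataset)-1 only.
for inputs, targets in loader:
    print(inputs, targets)
```

Both batches come out as (['x1', 'x1'], ['y2', 'y2']). No KeyError occurs, even though __getitem__ still uses label-based indexing.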

Thanks in advance!

Upvotes: 2

Views: 325

Answers (1)

Barney Stinson

Reputation: 1002

If you want to use for-each loops, you have to implement an iterator method (__iter__). Here is an example from the PyTorch documentation:

https://pytorch.org/docs/stable/data.html#torch.utils.data.IterableDataset

I modified it slightly, and it works for me:

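import math

import torch


class DataframeDataset(torch.utils.data.Dataset):
    ...  # __init__, __len__, __getitem__ etc. as in the question

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # Single-process case: yield every row in order.
            return map(self.__getitem__, range(len(self)))

        # Called inside a DataLoader worker process: give each worker
        # a contiguous, roughly equal slice of the index range.
        per_worker = int(math.ceil(len(self) / float(worker_info.num_workers)))
        worker_id = worker_info.id
        iter_start = worker_id * per_worker
        iter_end = min(iter_start + per_worker, len(self))
        return map(self.__getitem__, range(iter_start, iter_end))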
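An alternative that keeps the class map-style and needs no __iter__ relies on the fallback protocol. Python's fallback iteration stops on IndexError. Positional lookup with Series.iloc raises IndexError for an out-of-bounds integer, whereas label lookup with series[idx] raises KeyError. Below is a sketch of that change, reduced to the relevant methods; IlocDataframeDataset is a hypothetical name:

```python
import pandas as pd
import torch


class IlocDataframeDataset(torch.utils.data.Dataset):
    """Hypothetical variant of DataframeDataset using positional lookups."""

    def __init__(self, data_frame, input_key, target_key):
        self.input_key = input_key
        self.target_key = target_key
        self.inputs = data_frame[input_key]
        self.targets = data_frame[target_key]

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # .iloc is positional: an out-of-range integer raises IndexError,
        # which ends a plain for-loop instead of crashing it.
        return {
            self.input_key: self.inputs.iloc[idx],
            self.target_key: self.targets.iloc[idx],
        }


df = pd.DataFrame([("x1", "y2", "A3"), ("x1", "y2", "b3")], columns=["input", "target", "random"])
for idx, data in enumerate(IlocDataframeDataset(df, "input", "target")):
    print(idx, "->", data)
```

Here the loop prints two rows and then ends, because self.inputs.iloc[2] raises IndexError.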

Upvotes: 1
