Raven Cheuk

Reputation: 3053

Loading a custom dataset in PyTorch

Normally, when we load data in PyTorch, we do the following:

for x, y in dataloaders:
    # Do something

However, in a dataset called MusicNet, they define their own dataset and dataloader like this:

train_set = musicnet.MusicNet(root=root, train=True, download=True, window=window)#, pitch_shift=5, jitter=.1)
test_set = musicnet.MusicNet(root=root, train=False, window=window, epoch_size=50000)

train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**kwargs)

Then they load the data like this:

with train_set, test_set:
    for i, (x, y) in enumerate(train_loader):
        # Do something

Question 1

I don't understand why the code doesn't work without the line with train_set, test_set.

Question 2

Also, how do I access the data?

I tried

train_set.access(2560,0)

and

with train_set, test_set:
    x, y = train_set.access(2560,0)

They either give me an error message like

KeyError                                  Traceback (most recent call last)
in
----> 1 train_set.access(2560,0)

/workspace/raven_data/AMT/MusicNet/pytorch_musicnet/musicnet.py in access(self, rec_id, s, shift, jitter)
    106
    107         if self.mmap:
--> 108             x = np.frombuffer(self.records[rec_id][0][s*sz_float:int(s+scale*self.window)*sz_float], dtype=np.float32).copy()
    109         else:
    110             fid,_ = self.records[rec_id]

KeyError: 2560

or give me an empty x and y.

Upvotes: 1

Views: 789

Answers (1)

ndrwnaguib

Reputation: 6135

Question 1

I don't understand why the code doesn't work without the line with train_set, test_set.

To use torch.utils.data.DataLoader with a custom dataset, you must create a class for your dataset that subclasses torch.utils.data.Dataset (and implements specific functions) and pass it to the dataloader. The documentation itself says so:

All other datasets should subclass it. All subclasses should override __len__, that provides the size of the dataset, and __getitem__, supporting integer indexing in range from 0 to len(self) exclusive.

This is what happens in:

train_set = musicnet.MusicNet(root=root, train=True, download=True, window=window)#, pitch_shift=5, jitter=.1)

test_set = musicnet.MusicNet(root=root, train=False, window=window, epoch_size=50000)

train_loader = torch.utils.data.DataLoader(dataset=train_set,batch_size=batch_size,**kwargs)
test_loader = torch.utils.data.DataLoader(dataset=test_set,batch_size=batch_size,**kwargs)

If you check their musicnet.MusicNet class, you will find that it does exactly that.
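
As an illustration (a minimal sketch, not the actual MusicNet code), a custom dataset satisfying those requirements could look like this:

import torch
from torch.utils.data import Dataset, DataLoader

# Minimal custom dataset sketch (illustrative only, not MusicNet itself).
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data        # e.g. a tensor of inputs
        self.labels = labels    # e.g. a tensor of targets

    def __len__(self):
        # Size of the dataset
        return len(self.data)

    def __getitem__(self, idx):
        # Integer indexing in range 0 .. len(self) - 1
        return self.data[idx], self.labels[idx]

# Such a dataset can then be passed to a DataLoader, just like MusicNet is above.
dataset = MyDataset(torch.randn(100, 16), torch.randint(0, 2, (100,)))
loader = DataLoader(dataset, batch_size=8, shuffle=True)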

Question 2

Also, how do I access the data?

There are a few possible ways:

To get only a batch from the dataset, you can do:

batch = next(iter(train_loader))
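
To unpack that batch into inputs and targets (assuming the loader yields (x, y) pairs, as in the training loop above; with MusicNet this presumably still has to run inside the with train_set: block from the question):

x, y = next(iter(train_loader))
print(x.shape, y.shape)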

To access the whole dataset (specifically, in your example):

dataset = train_loader.dataset.records

(The .records attribute is the part that may vary from one dataset to another; I used .records because that is what I found in the MusicNet source here.)
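
As for the KeyError: 2560 from your attempt, the traceback suggests records is a dict keyed by recording id, so access() only works for ids that actually exist in it. A sketch of how you might explore it (an assumption based on the snippet in your question, not on the MusicNet docs):

with train_set:
    # records appears to be keyed by recording id, so list the valid keys first
    print(list(train_set.records.keys())[:5])
    # then pass an existing rec_id to access() instead of an arbitrary number like 2560
    rec_id = list(train_set.records.keys())[0]
    x, y = train_set.access(rec_id, 0)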

Upvotes: 1
