I need to use a BatchSampler within a PyTorch DataLoader instead of calling __getitem__ of the dataset multiple times (it is a remote dataset, and each query is pricey). I cannot understand how to use the BatchSampler with any given dataset.
e.g.:
class MyDataset(Dataset):
    def __init__(self, remote_ddf):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, idx):
        return self.ddf[idx]  # <-- this is as expensive as a batch call

    def get_batch(self, batch_idx):
        return self.ddf[batch_idx]

my_loader = DataLoader(MyDataset(remote_ddf),
                       batch_sampler=BatchSampler(Sampler(), batch_size=3, drop_last=False))
What I do not understand, and could not find any example of online or in the torch docs, is how to make the loader use my get_batch function instead of the __getitem__ function.
Edit:
Following Szymon Maszke's answer, this is what I tried, and yet __getitem__ still gets one index per call instead of a list of size batch_size:
class Dataset(Dataset):
    def __init__(self):
        ...

    def __len__(self):
        ...

    def __getitem__(self, batch_idx):  # <-- here I get only one index
        return self.wiki_df.loc[batch_idx]
loader = DataLoader(
    dataset=dataset,
    batch_sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False),
    num_workers=self.hparams.num_data_workers,
)
You can't use get_batch instead of __getitem__, and there is no need to do it like that. torch.utils.data.BatchSampler takes indices from your Sampler() instance (in this case, 3 of them) and returns them as a list, so those can be used in your MyDataset's __getitem__ method (check the source code; most of the samplers and data-related utilities are easy to follow if you need to).
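A minimal sketch of that behaviour (the length-8 range is just a stand-in for a dataset):
import torch
from torch.utils.data import BatchSampler, SequentialSampler

# BatchSampler groups the indices produced by the wrapped sampler into
# lists of batch_size; the final, shorter list is kept because
# drop_last=False.
batches = list(BatchSampler(SequentialSampler(range(8)), batch_size=3, drop_last=False))
print(batches)  # [[0, 1, 2], [3, 4, 5], [6, 7]]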
I assume your self.ddf supports list slicing (e.g. self.ddf[[25, 44, 115]] returns the values correctly and uses only one expensive call). In that case, simply switch get_batch into __getitem__ and you are good to go:
class MyDataset(Dataset):
    def __init__(self, remote_ddf):
        self.ddf = remote_ddf

    def __len__(self):
        return len(self.ddf)

    def __getitem__(self, batch_idx):
        return self.ddf[batch_idx]  # batch_idx is a list
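To illustrate the list-slicing assumption, here is a hypothetical stand-in where the backing store is a pandas DataFrame (your remote_ddf may differ, as long as a list lookup costs one call):
import pandas as pd

# Hypothetical stand-in for remote_ddf: .loc with a list of labels
# returns all requested rows in a single lookup.
ddf = pd.DataFrame({"text": [f"doc {i}" for i in range(200)]})
print(ddf.loc[[25, 44, 115]])  # three rows fetched in one call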
EDIT: You have to pass the BatchSampler as the sampler argument, not as batch_sampler, otherwise the batch will be divided back into single indices. This should be fine:
loader = DataLoader(
    dataset=dataset,
    # This line below!
    sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=self.hparams.batch_size, drop_last=False
    ),
    num_workers=self.hparams.num_data_workers,
)
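Note that with the default batch_size=1 the DataLoader still wraps each fetched batch in an extra leading dimension of size 1. Passing batch_size=None (which disables automatic batching, available since PyTorch 1.2) lets the index list reach __getitem__ untouched. A self-contained sketch, with a toy in-memory dataset standing in for the remote one:
import torch
from torch.utils.data import Dataset, DataLoader, BatchSampler, SequentialSampler

class ToyDataset(Dataset):
    """Toy stand-in for the remote dataset: one lookup per __getitem__ call."""
    def __init__(self):
        self.data = torch.arange(10)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, batch_idx):
        # batch_idx arrives as a list of indices; the whole batch is
        # fetched with a single (in the remote case, expensive) lookup.
        return self.data[batch_idx]

dataset = ToyDataset()
loader = DataLoader(
    dataset=dataset,
    sampler=BatchSampler(SequentialSampler(dataset), batch_size=3, drop_last=False),
    batch_size=None,  # disable automatic batching; the index list passes through intact
)
for batch in loader:
    print(batch.shape)  # torch.Size([3]) three times, then torch.Size([1])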