Reputation: 65
Hi, is there any way to apply a transformation to a certain batch only?
That is, I want to apply a transformation to just the last batch in every epoch.
Here is what I tried:
import torch

class test(torch.utils.data.Dataset):
    def __init__(self):
        self.source = [i for i in range(10)]
    def __len__(self):
        return len(self.source)
    def __getitem__(self, idx):
        print(idx)
        return self.source[idx]

ds = test()
dl = torch.utils.data.DataLoader(dataset = ds, batch_size = 3,
                                 shuffle = False, num_workers = 5)
for i in dl:
    print(i)
I thought that if I could get the idx number, it would be possible to apply the transformation to certain batches.
However, with num_workers set, the output is:
0
1
2
3
964
57
8
tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])
which is not what I expected.
Without num_workers, the output is:
0
1
2
tensor([0, 1, 2])
3
4
5
tensor([3, 4, 5])
6
7
8
tensor([6, 7, 8])
9
tensor([9])
So the question is
Upvotes: 1
Views: 1143
Reputation: 65
I found that
class test_dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.a = [i for i in range(100)]
    def __len__(self):
        return len(self.a)
    def __getitem__(self, idx):
        a = torch.tensor(self.a[idx])
        #print(idx)
        return idx

a = torch.utils.data.DataLoader(
    test_dataset(), batch_size = 10, shuffle = False,
    num_workers = 10, pin_memory = True)
for i in a:
    print(i)
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])
tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])
tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])
tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
Upvotes: 0
Reputation: 8497
When you have num_workers > 1, you have multiple subprocesses doing data loading in parallel. So what is likely happening is a race condition on the print step: the order you see in the output depends on which subprocess goes first each time. The batches themselves still come out of the DataLoader in order, as the tensor lines in your output show.
For most transforms, you can apply them on a specific batch simply by calling the transform after the batch has been loaded. To do this just for the last batch, you could do something like:
for batch_idx, batch_data in enumerate(dl):
    # check if batch is the last batch
    if (batch_idx + 1) * dl.batch_size >= len(ds):
        # transform is whatever callable you want to apply
        batch_data = transform(batch_data)
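For reference, here is a minimal runnable sketch of the same idea using the toy dataset from the question. The names RangeDataset, negate and run_epoch are made up for illustration, and negate is only a stand-in for whatever transform you actually need. It detects the last batch with len(dl), which is the number of batches per epoch, rather than with the batch_size arithmetic; both checks should agree when drop_last is False.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class RangeDataset(Dataset):
    """Toy dataset mirroring the question: items are the integers 0..n-1."""

    def __init__(self, n=10):
        self.source = list(range(n))

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        return self.source[idx]


def negate(batch):
    # Hypothetical stand-in for the real transform.
    return -batch


def run_epoch(dl, transform):
    """Collect one epoch of batches, transforming only the final one."""
    last_idx = len(dl) - 1  # len(dl) is the number of batches per epoch
    batches = []
    for batch_idx, batch in enumerate(dl):
        if batch_idx == last_idx:
            batch = transform(batch)
        batches.append(batch)
    return batches


ds = RangeDataset(10)
# num_workers=0 keeps the sketch portable; the check itself does not
# depend on the number of workers.
dl = DataLoader(ds, batch_size=3, shuffle=False, num_workers=0)
batches = run_epoch(dl, negate)
for batch in batches:
    print(batch)
```

Because the check runs in the main loop on the batch index, not inside __getitem__, it is not affected by the order in which worker subprocesses happen to fetch individual samples.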
Upvotes: 1