Jake

Reputation: 65

PyTorch transformation for just a certain batch

Hi, is there any method to apply a transformation to just a certain batch?

That is, I want to apply a transformation to only the last batch of every epoch.

Here is what I tried:

import torch


class test(torch.utils.data.Dataset):
    def __init__(self):
        self.source = [i for i in range(10)]

    def __len__(self):
        return len(self.source)

    def __getitem__(self, idx):
        # print the index so we can see which items are fetched, and in what order
        print(idx)
        return self.source[idx]

ds = test()
dl = torch.utils.data.DataLoader(dataset=ds, batch_size=3,
                                 shuffle=False, num_workers=5)

for i in dl:
    print(i)

I thought that if I could get the idx number, it would be possible to apply a transformation to certain batches.

However, when using num_workers, the output is:

0
1
2
3
9
6
4
5
7
8
tensor([0, 1, 2])
tensor([3, 4, 5])
tensor([6, 7, 8])
tensor([9])

which is not what I expected.

Without num_workers:

0
1
2
tensor([0, 1, 2])
3
4
5
tensor([3, 4, 5])
6
7
8
tensor([6, 7, 8])
9
tensor([9])

So my questions are:

  1. Why does idx behave this way when num_workers is set?
  2. How can I apply a transform to certain batches (or certain idx)?

Upvotes: 1

Views: 1143

Answers (2)

Jake

Reputation: 65

I found that even with multiple workers, the batches themselves still come out in order:

class test_dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.a = [i for i in range(100)]

    def __len__(self):
        return len(self.a)

    def __getitem__(self, idx):
        # return the index itself so the output shows which items land in each batch
        return idx

loader = torch.utils.data.DataLoader(
        test_dataset(), batch_size=10, shuffle=False,
        num_workers=10, pin_memory=True)

for i in loader:
    print(i)


tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([10, 11, 12, 13, 14, 15, 16, 17, 18, 19])
tensor([20, 21, 22, 23, 24, 25, 26, 27, 28, 29])
tensor([30, 31, 32, 33, 34, 35, 36, 37, 38, 39])
tensor([40, 41, 42, 43, 44, 45, 46, 47, 48, 49])
tensor([50, 51, 52, 53, 54, 55, 56, 57, 58, 59])
tensor([60, 61, 62, 63, 64, 65, 66, 67, 68, 69])
tensor([70, 71, 72, 73, 74, 75, 76, 77, 78, 79])
tensor([80, 81, 82, 83, 84, 85, 86, 87, 88, 89])
tensor([90, 91, 92, 93, 94, 95, 96, 97, 98, 99])
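
Since the batches arrive in index order, __getitem__ can tell which indices belong to the last batch, so a per-item transform can be applied right there. Below is a rough sketch building on the code above; the transform argument is a hypothetical placeholder, and it assumes the dataset length is divisible by the batch size:

import torch

class test_dataset(torch.utils.data.Dataset):
    def __init__(self, transform=None):
        self.a = [i for i in range(100)]
        self.transform = transform   # hypothetical per-item transform
        self.batch_size = 10         # must match the DataLoader's batch_size

    def __len__(self):
        return len(self.a)

    def __getitem__(self, idx):
        item = torch.tensor(self.a[idx])
        # with 100 items and batch_size 10, indices 90..99 form the last batch
        if self.transform is not None and idx >= len(self.a) - self.batch_size:
            item = self.transform(item)
        return item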

Upvotes: 0

GoodDeeds

Reputation: 8497

  1. When you have num_workers > 1, multiple subprocesses do the data loading in parallel. So what is likely happening is a race condition on the print step: the order you see in the output depends on which subprocess reaches it first each time. You can make this visible by tagging each print with its worker, as in the sketch after this list.

  2. For most transforms, you can apply them to a specific batch simply by calling the transform after the batch has been loaded. To do this just for the last batch, you could do something like:

    batch_size = 3  # must match the batch_size given to the DataLoader
    for batch_idx, batch_data in enumerate(dl):
        # check if this batch is the last batch
        if (batch_idx + 1) * batch_size >= len(ds):
            batch_data = transform(batch_data)  # transform: whatever callable you want to apply
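A variant of the same check, assuming the same dl and transform as above: a DataLoader over a sized, map-style dataset reports its batch count via len(dl), so the last batch is simply batch index len(dl) - 1.

    for batch_idx, batch_data in enumerate(dl):
        if batch_idx == len(dl) - 1:  # len(dl) is the number of batches
            batch_data = transform(batch_data)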

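And here is the sketch mentioned in point 1: the question's dataset with each print tagged via torch.utils.data.get_worker_info(), so the interleaving across workers becomes visible.

    import torch
    from torch.utils.data import Dataset, DataLoader, get_worker_info

    class test(Dataset):
        def __init__(self):
            self.source = [i for i in range(10)]

        def __len__(self):
            return len(self.source)

        def __getitem__(self, idx):
            info = get_worker_info()  # None in the main process, set inside worker processes
            worker_id = info.id if info is not None else "main"
            print(f"worker {worker_id} loading idx {idx}")
            return self.source[idx]

    dl = DataLoader(test(), batch_size=3, shuffle=False, num_workers=5)
    for batch in dl:
        print(batch)
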
Upvotes: 1
