TFC

Reputation: 546

PyTorch: How to apply the same random transformation to multiple images?

I am writing a simple transformation for a dataset which contains many pairs of images. As data augmentation, I want to apply some random transformation to each pair, but the images within a pair should be transformed in the same way. For example, given a pair of two images A and B, if A is flipped horizontally, B must be flipped horizontally as well. The next pair C and D may be transformed differently from A and B, but C and D must again be transformed in the same way as each other. I am trying to do that as shown below:

import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(), 
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))

random.seed(1)
display(transform(img_c))
display(transform(img_d))

Yet the above code does not choose the same transformation for both images of a pair, and as far as I tested, the outcome depends on the number of times transform has been called.

Is there any way to force transforms.RandomChoice to use the same transform when specified?

Upvotes: 15

Views: 22148

Answers (5)

Ivan

Reputation: 40708

The usual workaround is to apply the transform on the first image, retrieve the parameters of that transform, and then apply a deterministic transform with those parameters to the remaining images. However, RandomChoice does not provide an API to get the parameters of the applied transform, since it involves a variable number of transforms. In those cases, I usually implement an override of the original class.
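
For transforms that do expose their parameters, that workaround looks roughly like the sketch below, using RandomResizedCrop.get_params as an example (img_a and img_b are the images from the question; the scale, ratio, and output size are arbitrary choices made for illustration):

import torchvision.transforms as T
import torchvision.transforms.functional as TF

# sample the crop parameters once, from the first image of the pair
i, j, h, w = T.RandomResizedCrop.get_params(img_a, scale=(0.8, 1.0), ratio=(3/4, 4/3))

# reuse exactly the same parameters on both images
img_a_t = TF.resized_crop(img_a, i, j, h, w, size=(224, 224))
img_b_t = TF.resized_crop(img_b, i, j, h, w, size=(224, 224))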

Looking at the torchvision implementation, it's as simple as:

class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)

Here are two possible solutions.

  1. You can either sample from the transform list on __init__ instead of on __call__:

    import random
    import torch
    import torchvision.transforms as T

    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
            super().__init__()
            self.transforms = transforms
            # the transform is chosen once, at construction time
            self.t = random.choice(self.transforms)

        def __call__(self, img):
            return self.t(img)
    

    So you can do:

    transform = RandomChoice([
         T.RandomHorizontalFlip(), 
         T.RandomVerticalFlip()
    ])
    display(transform(img_a)) # both img_a and img_b will
    display(transform(img_b)) # have the same transform
    
    transform = RandomChoice([
        T.RandomHorizontalFlip(), 
        T.RandomVerticalFlip()
    ])
    display(transform(img_c)) # both img_c and img_d will
    display(transform(img_d)) # have the same transform
    

  2. Or better yet, transform the images as a batch:

    import random
    import torch
    import torchvision.transforms as T

    class RandomChoice(torch.nn.Module):
        def __init__(self, transforms):
            super().__init__()
            self.transforms = transforms

        def __call__(self, imgs):
            # one random choice per call, applied to every image passed in
            t = random.choice(self.transforms)
            return [t(img) for img in imgs]
    

    Which allows you to do:

    transform = RandomChoice([
         T.RandomHorizontalFlip(), 
         T.RandomVerticalFlip()
    ])
    
    img_at, img_bt = transform([img_a, img_b])
    display(img_at) # both img_a and img_b will
    display(img_bt) # have the same transform
    
    img_ct, img_dt = transform([img_c, img_d])
    display(img_ct) # both img_c and img_d will
    display(img_dt) # have the same transform
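
Since the question mentions a dataset of image pairs, this batched variant also slots naturally into a Dataset's __getitem__. Here is a rough sketch (PairDataset and its path lists are made up for illustration, not part of the original answer):

from torch.utils.data import Dataset
from PIL import Image

class PairDataset(Dataset):
    def __init__(self, paths_a, paths_b, transform=None):
        self.paths_a, self.paths_b = paths_a, paths_b
        self.transform = transform

    def __len__(self):
        return len(self.paths_a)

    def __getitem__(self, idx):
        img_a = Image.open(self.paths_a[idx])
        img_b = Image.open(self.paths_b[idx])
        if self.transform is not None:
            # one random choice per pair, applied identically to both images
            img_a, img_b = self.transform([img_a, img_b])
        return img_a, img_b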
    

Upvotes: 11

Ivan Gonzalez

Reputation: 576

Referencing Random transforms for both input and target?, I think this is probably the cleanest way to do it. Save the random state before applying any transformation and then just restore it for each subsequent call:

t = transforms.RandomRotation(degrees=360)
state = torch.get_rng_state()
x = t(x)
torch.set_rng_state(state)
y = t(y)
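
This covers transforms that draw from PyTorch's RNG. RandomChoice from the question, however, picks the transform with Python's random.choice (see the torchvision snippet quoted in Ivan's answer), and the flip transforms may draw their coin flip from torch's RNG depending on the torchvision version. To be safe, you can save and restore both generators. A minimal sketch, assuming transform, img_a, and img_b are the ones from the question:

import random
import torch

# save both RNG states: random.choice picks the transform,
# while the chosen transform may draw from torch's RNG
py_state, torch_state = random.getstate(), torch.get_rng_state()
img_a_t = transform(img_a)

random.setstate(py_state)
torch.set_rng_state(torch_state)
img_b_t = transform(img_b)  # same choice and same coin flip as for img_a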

Upvotes: 6

Addison Klinke

Reputation: 1186

I realize the OP requested a solution using torchvision and I think @Ivan's answer does a good job addressing this.

However, for those not tied to a specific augmentation library, I wanted to point out that Albumentations handles these kinds of situations nicely in a native fashion by allowing the user to pass multiple source images, boxes, etc. into the same transform. The return value is structured as a dict:

import albumentations as A

transform = A.Compose(
    transforms=[
        A.VerticalFlip(p=0.5),
        A.HorizontalFlip(p=0.5)],
    additional_targets={'image0': 'image', 'image1': 'image'}
)
transformed = transform(image=image, image0=image0, image1=image1)

Now you can access transformed['image0'], transformed['image1'], etc., and all of them will have the same random parameters applied.

Upvotes: 4

Abhi25t

Reputation: 4683

Simply take the randomization out of PyTorch and into an if statement. The code below uses vflip; the same pattern works for horizontal flips or other transforms.

import random
import torchvision.transforms.functional as TF

if random.random() > 0.5:
    image = TF.vflip(image)
    mask  = TF.vflip(mask)

This issue has been discussed on the PyTorch forum. The pros and cons of several solutions were discussed on the official GitHub repository page. PyTorch maintainers have suggested this simple approach.

Do not use torchvision.transforms.RandomVerticalFlip(p=1). Use torchvision.transforms.functional.vflip instead.

Functional transforms give you fine-grained control of the transformation pipeline. As opposed to the transformations above, functional transforms don’t contain a random number generator for their parameters. That means you have to specify/generate all parameters, but you can reuse the functional transform.
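
The same pattern covers parameterised transforms: generate the parameter yourself once, then reuse it in the functional call for every image of the pair. A small sketch (the rotation range of ±30 degrees is an arbitrary choice):

import random
import torchvision.transforms.functional as TF

# draw the angle once, then rotate image and mask identically
angle = random.uniform(-30, 30)
image = TF.rotate(image, angle)
mask  = TF.rotate(mask, angle)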

Upvotes: 4

Salman Hammad

Reputation: 41

I don't know of a function to fix the random output. Maybe try a different logic, like creating the randomization yourself so you can reuse the same transformation. The logic:

  • generate a random number
  • based on the number, apply a transformation to both images
  • generate another random number
  • do the same for the other two images

Try this:
import random
import torchvision.transforms.functional as TF
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

# pick one transform per pair, then apply it to both images of that pair
if random.random() > 0.5:
    image_a_flipped = TF.vflip(img_a)
    image_b_flipped = TF.vflip(img_b)
else:
    image_a_flipped = TF.hflip(img_a)
    image_b_flipped = TF.hflip(img_b)

if random.random() > 0.5:
    image_c_flipped = TF.vflip(img_c)
    image_d_flipped = TF.vflip(img_d)
else:
    image_c_flipped = TF.hflip(img_c)
    image_d_flipped = TF.hflip(img_d)
    
display(image_a_flipped)
display(image_b_flipped)

display(image_c_flipped)
display(image_d_flipped)

Upvotes: 0
