iggy

Reputation: 1743

Torchvision v2 joint transform

Torchvision transforms v2 promises to apply the same transform to both inputs; however, that doesn't seem to happen:

import torch
import torchvision.transforms.v2 as transforms_v2

joint_transform = transforms_v2.Compose([
    transforms_v2.ToTensor(),
    transforms_v2.RandomHorizontalFlip(p=0.5),
    transforms_v2.RandomVerticalFlip(p=0.5),
    transforms_v2.RandomRotation(degrees=45),
])

X_pre, y_pre = torch.rand(3, 512, 512), torch.rand(1, 512, 512)
X_post, y_post = joint_transform(X_pre, y_pre)

print((X_pre == X_post).all())
print((y_pre == y_post).all())
> True
> False

y_post, X_post = joint_transform(y_pre, X_pre)

print((X_pre == X_post).all())
print((y_pre == y_post).all())
> False
> True

It seems like the transform is only applied to the first argument.

What am I missing?

Upvotes: 0

Views: 12

Answers (1)

iggy

Reputation: 1743

Apparently I was missing a lot:

  1. One has to wrap both the image and the mask with tv_tensors.Image and tv_tensors.Mask respectively (see the sketch after this list)

  2. Independently of that, joint_transform needs to have ToTensor as its last operation

  3. ToTensor has been deprecated and needs to be replaced with transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()])
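Putting the three points together, here is a minimal sketch. It assumes a recent torchvision (0.16 or later), where the wrapper classes live in torchvision.tv_tensors and ToImageTensor / ConvertImageDtype from point 3 have since been renamed to ToImage / ToDtype; shapes and transform parameters mirror the question.

import torch
import torchvision.transforms.v2 as transforms_v2
from torchvision import tv_tensors

joint_transform = transforms_v2.Compose([
    transforms_v2.RandomHorizontalFlip(p=0.5),
    transforms_v2.RandomVerticalFlip(p=0.5),
    transforms_v2.RandomRotation(degrees=45),
    # points 2 and 3: tensor conversion goes last, using the non-deprecated ops
    transforms_v2.ToImage(),
    transforms_v2.ToDtype(torch.float32, scale=True),
])

# point 1: wrap the inputs so v2 knows which one is the image and which is the mask
X_pre = tv_tensors.Image(torch.rand(3, 512, 512))
y_pre = tv_tensors.Mask(torch.rand(1, 512, 512))

X_post, y_post = joint_transform(X_pre, y_pre)

print((X_pre == X_post).all())  # now False: the image was transformed
print((y_pre == y_post).all())  # now False: the mask received the same flips/rotation

With the inputs wrapped, the random flips and the rotation are applied jointly to both the image and the mask, so both comparisons come out False (barring the rare draw where no flip and an essentially zero rotation are sampled).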

Upvotes: 0
