Reputation: 1743
Torchvision transforms v2 promises to apply the same random transform to all inputs passed together, but that doesn't seem to happen:
import torch
import torchvision.transforms.v2 as transforms_v2

joint_transform = transforms_v2.Compose([
    transforms_v2.ToTensor(),
    transforms_v2.RandomHorizontalFlip(p=0.5),
    transforms_v2.RandomVerticalFlip(p=0.5),
    transforms_v2.RandomRotation(degrees=45),
])
X_pre, y_pre = torch.rand(3, 512, 512), torch.rand(1, 512, 512)
X_post, y_post = joint_transform(X_pre, y_pre)
print((X_pre == X_post).all())
print((y_pre == y_post).all())
> True
> False
y_post, X_post = joint_transform(y_pre, X_pre)
print((X_pre == X_post).all())
print((y_pre == y_post).all())
> False
> True
It seems like the random transforms are only applied to the second argument, whichever order I pass them in.
What am I missing?
Upvotes: 0
Views: 12
Reputation: 1743
Apparently I was missing a lot:

- One has to convert / wrap the image and the mask with tv_tensors.Image and tv_tensors.Mask respectively, so that the v2 transforms recognize both as parts of one sample and apply the same random parameters to each.
- Independently of that, joint_transform needs to put the tensor/dtype conversion last, after the random geometric transforms.
- ToTensor has been deprecated and needs to be replaced with transforms.Compose([transforms.ToImageTensor(), transforms.ConvertImageDtype()]).
Upvotes: 0