John Sall

Reputation: 1151

How to extract patches from an image in pytorch?

I want to extract image patches from an image with patch size 128 and stride 32, so I have this code, but it gives me an error:

from PIL import Image
from torchvision import transforms

img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)

x = x.unsqueeze(0)

size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)

and the error I get is:

RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128

This is the only method I've found so far, but it gives me this error.

Upvotes: 4

Views: 4611

Answers (1)

Michael Jungo

Reputation: 32982

The size of your x is [1, 3, height, width]. Calling x.unfold(1, size, stride) tries to create slices of size 128 from dimension 1, which has size 3, hence it is too small to create any slice.

You don't want to create slices across dimension 1, since those are the channels of the image (RGB in this case) and they need to be kept as they are for all patches. The patches are only created across the height and width of an image.

patches = x.unfold(2, size, stride).unfold(3, size, stride)

The resulting tensor will have size [1, 3, num_vertical_slices, num_horizontal_slices, 128, 128]. You can reshape it to combine the two slice dimensions into a single patch dimension, i.e. size [1, 3, num_patches, 128, 128]:

patches = patches.reshape(1, 3, -1, size, size)
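As a self-contained sanity check (using a random 256×256 tensor in place of the image, so the expected slice counts are easy to verify by hand), the shapes work out as:

```python
import torch

size = 128   # patch size
stride = 32  # patch stride

# dummy "image" batch: 1 image, 3 channels, 256x256 pixels
x = torch.randn(1, 3, 256, 256)

# slice only along height (dim 2) and width (dim 3),
# leaving the batch and channel dimensions intact
patches = x.unfold(2, size, stride).unfold(3, size, stride)
# (256 - 128) / 32 + 1 = 5 slices per spatial dimension
print(patches.shape)  # torch.Size([1, 3, 5, 5, 128, 128])

# flatten the two slice dimensions into one patch dimension
patches = patches.reshape(1, 3, -1, size, size)
print(patches.shape)  # torch.Size([1, 3, 25, 128, 128])
```

Note that `unfold` returns a view over the original tensor; `reshape` will copy the data if needed, so the final `patches` tensor is safe to use independently of `x`.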

Upvotes: 11
