Reputation: 1151
I want to extract patches from an image with patch size 128 and stride 32. I have the following code, but it gives me an error:
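from PIL import Image
from torchvision import transforms
img = Image.open("cat.jpg")
x = transforms.ToTensor()(img)  # tensor of shape [channels, height, width]
x = x.unsqueeze(0)  # add a batch dimension: [1, channels, height, width]
size = 128 # patch size
stride = 32 # patch stride
patches = x.unfold(1, size, stride).unfold(2, size, stride).unfold(3, size, stride)
print(patches.shape)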
The error I get is:
RuntimeError: maximum size for tensor at dimension 1 is 3 but size is 128
This is the only method I've found so far, but it gives me this error.
Upvotes: 4
Views: 4611
Reputation: 32982
The size of your x is [1, 3, height, width]. Calling x.unfold(1, size, stride) tries to create slices of size 128 from dimension 1, which has size 3, hence it is too small to create any slice.
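To make the behaviour of unfold concrete, here is a minimal sketch on toy tensors. The names seq, windows, channels_first and failed are made up for illustration and are not part of the question's code. It shows that unfold(dim, size, step) slides a window along a single dimension, and that it raises when the dimension is shorter than the window.

```python
import torch

# A 1-D tensor with 10 elements: windows of size 4 taken every 2 elements.
seq = torch.arange(10)
windows = seq.unfold(0, 4, 2)
print(windows.shape)  # torch.Size([4, 4]), since (10 - 4) // 2 + 1 = 4 windows
print(windows[1])     # tensor([2, 3, 4, 5])

# The channel dimension has only 3 entries, so a window of 128 cannot fit.
channels_first = torch.zeros(1, 3, 8, 8)
try:
    channels_first.unfold(1, 128, 32)
    failed = False
except RuntimeError as err:
    failed = True
    print(err)
```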
You don't want to create slices across dimension 1, since those are the channels of the image (RGB in this case) and they need to be kept as they are for all patches. The patches are only created across the height and width of an image.
patches = x.unfold(2, size, stride).unfold(3, size, stride)
The resulting tensor will have size [1, 3, num_vertical_slices, num_horizontal_slices, 128, 128]. You can reshape it to combine the slices to get a list of patches, i.e. size of [1, 3, num_patches, 128, 128]:
patches = patches.reshape(1, 3, -1, size, size)
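As a rough end-to-end check, the sketch below runs the same two steps on a random tensor standing in for the image. The 256x320 size and the names n_v, n_h, flat and per_patch are assumptions chosen for illustration. It also shows one possible way to move the patch axis to the front, which is an optional extra step if you want one [3, 128, 128] tensor per patch.

```python
import torch

size, stride = 128, 32
height, width = 256, 320                 # assumed image size, for illustration
x = torch.rand(1, 3, height, width)      # stand-in for the ToTensor() output

# Slide the window over height (dim 2), then over width (dim 3).
patches = x.unfold(2, size, stride).unfold(3, size, stride)

# Number of window positions along each axis.
n_v = (height - size) // stride + 1      # 5
n_h = (width - size) // stride + 1       # 7
print(patches.shape)                     # torch.Size([1, 3, 5, 7, 128, 128])

# Merge the two slice axes into a single patch axis.
flat = patches.reshape(1, 3, -1, size, size)
print(flat.shape)                        # torch.Size([1, 3, 35, 128, 128])

# Optional: patch axis first, giving one [3, 128, 128] image per patch.
per_patch = flat.squeeze(0).permute(1, 0, 2, 3)
print(per_patch.shape)                   # torch.Size([35, 3, 128, 128])
```

Patch k in the flattened axis corresponds to row k // n_h and column k % n_h of the window grid, so its top-left corner in the original image is at (row * stride, column * stride).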
Upvotes: 11