Reputation: 1189
Given a batch of images, I'd like to extract all possible image patches, similar to a convolution. In TensorFlow, we can use tf.extract_image_patches to achieve this. Is there an equivalent function in PyTorch?
Thank you.
Upvotes: 19
Views: 14720
Reputation: 51
You can use kornia's extract_tensor_patches. If needed, also look at its APIs for combining the patches back.
Upvotes: 2
Reputation: 363
I spent a bit of time looking into this as well and found this PyTorch forum thread useful: PyTorch dev ptrblck (bless this dude) gives a PyTorch equivalent of the TensorFlow function.
I'll just repost the code (from user FloCF) here for simplicity.
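import math
import torch.nn.functional as F

def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding: output size along each axis is ceil(input / stride)
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)
    w2 = math.ceil(w / stride)
    # Clamp at 0, as TF does, so a stride larger than the kernel never crops
    pad_row = max((h2 - 1) * stride + (kernel - 1) * dilation + 1 - h, 0)
    pad_col = max((w2 - 1) * stride + (kernel - 1) * dilation + 1 - w, 0)
    # F.pad takes (left, right, top, bottom): width padding first, then height
    x = F.pad(x, (pad_col // 2, pad_col - pad_col // 2, pad_row // 2, pad_row - pad_row // 2))
    # Extract patches using the dilated window size, then keep every `dilation`-th pixel
    eff_kernel = (kernel - 1) * dilation + 1
    patches = x.unfold(2, eff_kernel, stride).unfold(3, eff_kernel, stride)
    patches = patches[..., ::dilation, ::dilation]  # (b, c, h2, w2, kernel, kernel)
    # Reorder to (b, kernel, kernel, c, h2, w2), then flatten each patch into the channel dim
    patches = patches.permute(0, 4, 5, 1, 2, 3).contiguous()
    return patches.view(b, -1, patches.shape[-2], patches.shape[-1])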
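To make the 'SAME' padding arithmetic concrete, here is a small standalone sketch of the same per-axis formulas. The helper names same_padding and n_windows are hypothetical, introduced only for illustration. Following TF's convention, the extra pixel of an odd total padding goes after (bottom/right), and negative padding is clamped to zero.

```python
import math

def same_padding(size, kernel, stride=1, dilation=1):
    """Hypothetical helper: (pad_before, pad_after) for TF-style 'SAME' padding on one axis."""
    out = math.ceil(size / stride)
    eff_kernel = (kernel - 1) * dilation + 1
    total = max((out - 1) * stride + eff_kernel - size, 0)
    return total // 2, total - total // 2

def n_windows(size, kernel, stride=1, dilation=1):
    """Hypothetical helper: number of windows Tensor.unfold yields after that padding."""
    before, after = same_padding(size, kernel, stride, dilation)
    eff_kernel = (kernel - 1) * dilation + 1
    return (size + before + after - eff_kernel) // stride + 1

print(same_padding(64, kernel=3, stride=2))  # (0, 1): the extra row goes after
print(n_windows(64, kernel=3, stride=2))     # 32 == ceil(64 / 2)
print(same_padding(5, kernel=1, stride=3))   # (0, 0): negative padding is clamped
print(n_windows(5, kernel=1, stride=3))      # 2 == ceil(5 / 3)
```

In every case the window count equals ceil(size / stride), which is what 'SAME' padding is meant to guarantee.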
Give those guys a like on the PyTorch forum :)
Upvotes: 4
Reputation: 3117
Maybe this code example will help you understand how to use unfold. It is inspired by the thread linked by @gasoon, but a bit more verbose:
import torch

batch_size, n_channels, n_rows, n_cols = 32, 3, 64, 64
kernel_h, kernel_w = 7, 9
step = 5
x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)
# unfold(dimension, size, step)
windows = (
    x.unfold(2, kernel_h, step)    # (B, C, n_h, n_cols, kernel_h)
    .unfold(3, kernel_w, step)     # (B, C, n_h, n_w, kernel_h, kernel_w)
    .permute(2, 3, 0, 1, 4, 5)     # (n_h, n_w, B, C, kernel_h, kernel_w)
    .reshape(-1, n_channels, kernel_h, kernel_w)
)
print(windows.shape)
# result: torch.Size([4608, 3, 7, 9]) = [n_windows, n_channels, kernel_h, kernel_w]
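Because the permute/reshape order decides which flat index holds which window, a quick sanity check can help. This sketch uses smaller, illustrative sizes and a hypothetical helper window_index, and compares one window against plain slicing of x. Note that Tensor.unfold only produces full windows, so without padding any trailing rows/columns not covered by a full window are dropped.

```python
import torch

batch_size, n_channels, n_rows, n_cols = 2, 3, 16, 16
kernel_h, kernel_w = 7, 9
step = 5
x = torch.arange(batch_size * n_channels * n_rows * n_cols).view(batch_size, n_channels, n_rows, n_cols)

# Same pipeline as above; windows end up ordered by (window row, window col, batch item)
unfolded = x.unfold(2, kernel_h, step).unfold(3, kernel_w, step)  # (B, C, n_h, n_w, kh, kw)
n_h, n_w = unfolded.shape[2], unfolded.shape[3]
windows = unfolded.permute(2, 3, 0, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

def window_index(i, j, b):
    """Hypothetical helper: flat index into `windows` for window row i, col j, batch item b."""
    return (i * n_w + j) * batch_size + b

i, j, b = 1, 1, 1
expected = x[b, :, i * step : i * step + kernel_h, j * step : j * step + kernel_w]
print(windows.shape)                                      # torch.Size([8, 3, 7, 9])
print(torch.equal(windows[window_index(i, j, b)], expected))  # True
```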
Upvotes: 7
Reputation: 865
Unfortunately, there might not be a direct way to achieve your goal, but the Tensor.unfold function might be a solution. This forum thread might help you:
https://discuss.pytorch.org/t/how-to-extract-smaller-image-patches-3d/16837/2
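A minimal sketch of what Tensor.unfold(dimension, size, step) does: it slides a window of length size along dimension, moving step elements at a time, and appends the window contents as a new last dimension. Unfolding the two spatial dimensions one after the other gives 2-D patches. The toy tensors below are illustrative only.

```python
import torch

# 1-D: windows of size 3, moving 2 elements at a time along dimension 0
t = torch.arange(7)
print(t.unfold(0, 3, 2))
# tensor([[0, 1, 2],
#         [2, 3, 4],
#         [4, 5, 6]])

# 2-D: unfold the height (dim 2), then the width (dim 3) of a 1x1x4x4 "image"
img = torch.arange(16).view(1, 1, 4, 4)
patches = img.unfold(2, 2, 2).unfold(3, 2, 2)
print(patches.shape)        # torch.Size([1, 1, 2, 2, 2, 2]) = (B, C, n_h, n_w, 2, 2)
print(patches[0, 0, 0, 1])  # top-right 2x2 patch
# tensor([[2, 3],
#         [6, 7]])
```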
Upvotes: 7