Gabriel Bianconi

Reputation: 1189

Is there a function to extract image patches in PyTorch?

Given a batch of images, I'd like to extract all possible image patches, similar to a convolution. In TensorFlow, we can use tf.extract_image_patches to achieve this. Is there an equivalent function in PyTorch?

Thank you.

Upvotes: 19

Views: 14720

Answers (4)

edgarriba

Reputation: 51

You can use Kornia's extract_tensor_patches. Kornia also has APIs for combining the patches back, if needed.

Upvotes: 2

maraboule

Reputation: 363

I spent a bit of time looking into this as well and found a PyTorch forum thread that was useful for me, with PyTorch dev ptrblck (bless this dude) giving a PyTorch equivalent of the TensorFlow function.

I'll just repost the code (from user FloCF) here for simplicity.

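import math
import torch.nn.functional as F

def extract_image_patches(x, kernel, stride=1, dilation=1):
    # Do TF 'SAME' padding
    b, c, h, w = x.shape
    h2 = math.ceil(h / stride)  # output height
    w2 = math.ceil(w / stride)  # output width
    pad_row = (h2 - 1) * stride + (kernel - 1) * dilation + 1 - h
    pad_col = (w2 - 1) * stride + (kernel - 1) * dilation + 1 - w
    # F.pad takes the padding for the last dimension first: (left, right, top, bottom),
    # so the column (width) padding comes before the row (height) padding.
    x = F.pad(x, (pad_col//2, pad_col - pad_col//2, pad_row//2, pad_row - pad_row//2))

    # Extract patches. Note: dilation only enters the padding above;
    # Tensor.unfold itself takes contiguous (undilated) windows.
    patches = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
    patches = patches.permute(0,4,5,1,2,3).contiguous()  # (b, k, k, c, h_out, w_out)

    # Flatten to (b, kernel*kernel*c, h_out, w_out)
    return patches.view(b,-1,patches.shape[-2], patches.shape[-1])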
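
To make the 'SAME' padding arithmetic in the function above easier to follow, here is a minimal sketch for a single spatial dimension. The helper name same_padding is hypothetical, and clamping the total at zero is my assumption (it mirrors how TensorFlow documents 'SAME' padding); the reposted function does not clamp.

```python
import math

def same_padding(size, kernel, stride=1, dilation=1):
    """Return (pad_before, pad_after) for one spatial dimension (hypothetical helper)."""
    out = math.ceil(size / stride)                  # number of output positions
    effective_kernel = (kernel - 1) * dilation + 1  # kernel extent including dilation gaps
    # Assumption: clamp at zero so the input is never cropped.
    total = max((out - 1) * stride + effective_kernel - size, 0)
    return total // 2, total - total // 2           # any odd pixel goes after

print(same_padding(5, kernel=3, stride=2))   # (1, 1)
print(same_padding(64, kernel=7, stride=5))  # (1, 2)
```

For a 64-pixel dimension with kernel 7 and stride 5, the output has ceil(64 / 5) = 13 positions, so the total padding is 12 * 5 + 7 - 64 = 3, split as 1 before and 2 after.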

Give those guys a like on the PyTorch forum :)

Upvotes: 4

DalyaG

Reputation: 3117

Maybe this code example will help to understand how to use unfold, inspired by this thread linked by @gasoon, but a bit more verbose:

import torch

batch_size, n_channels, n_rows, n_cols = 32, 3, 64, 64
kernel_h, kernel_w = 7, 9
step = 5

x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)

# unfold(dimension, size, step)
windows = x.unfold(2, kernel_h, step).unfold(3, kernel_w, step).permute(2, 3, 0, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([4608, 3, 7, 9]) = [n_windows, n_channels, kernel_h, kernel_w]
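
If it helps to see which window ends up where, here is a small sanity check using the same unfold/permute/reshape pattern on a smaller tensor. The helper window_index is hypothetical; it assumes the permute(2, 3, 0, 1, 4, 5) order above, which makes the window row the slowest-varying index, then the window column, then the batch item.

```python
import torch

batch_size, n_channels, n_rows, n_cols = 2, 3, 16, 16
kernel_h, kernel_w, step = 7, 9, 5

x = torch.arange(batch_size * n_channels * n_rows * n_cols).view(
    batch_size, n_channels, n_rows, n_cols)

# Shape after both unfolds: (batch, channels, n_win_rows, n_win_cols, kernel_h, kernel_w)
unfolded = x.unfold(2, kernel_h, step).unfold(3, kernel_w, step)
n_win_rows, n_win_cols = unfolded.shape[2], unfolded.shape[3]
windows = unfolded.permute(2, 3, 0, 1, 4, 5).reshape(-1, n_channels, kernel_h, kernel_w)

def window_index(i, j, b):
    # Flat index of window (row i, col j) for batch item b (hypothetical helper).
    return (i * n_win_cols + j) * batch_size + b

# Compare one extracted window against a plain slice of the input.
i, j, b = 1, 0, 1
expected = x[b, :, i * step:i * step + kernel_h, j * step:j * step + kernel_w]
matches = torch.equal(windows[window_index(i, j, b)], expected)
print(n_win_rows, n_win_cols, matches)
```

The number of windows per dimension is (size - kernel) // step + 1, so the 64x64 example above gives 12 * 12 windows per image and 12 * 12 * 32 = 4608 in total.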

Upvotes: 7

gasoon

Reputation: 865

Unfortunately, there might not be a direct equivalent.
However, the Tensor.unfold function might be a solution.
This forum thread might help you:
https://discuss.pytorch.org/t/how-to-extract-smaller-image-patches-3d/16837/2
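
As a rough sketch of the Tensor.unfold idea (not taken from the linked thread), the snippet below slides a square window over the height and width dimensions. The comparison with torch.nn.functional.unfold, which returns the same patches in an im2col layout, is my own addition; the tensor sizes are arbitrary.

```python
import torch
import torch.nn.functional as F

batch, channels, size = 2, 3, 8
kernel, stride = 3, 2
x = torch.arange(batch * channels * size * size, dtype=torch.float32).view(
    batch, channels, size, size)

# Tensor.unfold: slide over height (dim 2), then over width (dim 3).
# Result shape: (batch, channels, n_h, n_w, kernel, kernel)
a = x.unfold(2, kernel, stride).unfold(3, kernel, stride)
n_h, n_w = a.shape[2], a.shape[3]

# F.unfold (im2col). Result shape: (batch, channels * kernel * kernel, n_h * n_w)
cols = F.unfold(x, kernel_size=kernel, stride=stride)

# Rearrange the im2col layout to match the Tensor.unfold layout.
b = cols.view(batch, channels, kernel, kernel, n_h, n_w).permute(0, 1, 4, 5, 2, 3)

same = torch.equal(a, b)
print(a.shape, cols.shape, same)
```

F.unfold also accepts padding and dilation arguments, which may be convenient when the patches should mimic a convolution's receptive fields.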

Upvotes: 7
