Reputation: 75
I am working with 3D images of shape 172x220x156. To feed an image through the network I need to extract patches of size 32x32x32 from it and then add them back together to reassemble the image. Since the image dimensions are not multiples of the patch size, the patches have to overlap. I want to know how to do that.
I am working in PyTorch; there are some options like unfold and fold, but I am not sure how they work.
Upvotes: 0
Views: 1919
Reputation: 419
To extract (overlapping) patches and to reconstruct the input shape, we can use torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only operate on 4D tensors (batched 2D images), but by applying them one dimension at a time we can handle 3D volumes as well.
A few notes:
- This approach requires the fold/unfold methods from PyTorch; unfortunately I have yet to find a similar method in the TF API.
- We can extract patches in two ways whose output is the same. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions (only the 3D versions are shown here). The latter uses torch.Tensor.unfold() and has fewer lines of code, but it cannot use dilation.
- The methods extract_patches_Xd and combine_patches_Xd are inverse operations, and the combiner reverses the steps of the extractor one by one.
The lines of code are followed by a comment stating the dimensionality, such as (B, C, D, H, W). The following symbols are used:
- B: batch size
- C: channels
- D: depth dimension
- H: height dimension
- W: width dimension
- x_dim_in: in the extraction method, the number of input pixels in dimension x; in the combining method, the number of sliding windows in dimension x
- x_dim_out: in the extraction method, the number of sliding windows in dimension x; in the combining method, the number of output pixels in dimension x

I have a public notebook to try out the code.
The get_dim_blocks() method is the function given in the PyTorch docs to compute the output shape of a convolutional layer.
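As a quick sanity check (my own, lifted from the functions below), the formula applied to the small demo at the bottom (input (2, 2, 2, 4, 4), kernel 2, padding 1, stride 1) predicts 3 * 5 * 5 = 75 windows per sample, i.e. 2 * 75 = 150 patches, matching the printed torch.Size([150, 2, 2, 2, 2]):

def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    # conv output-size formula from the PyTorch docs
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1

print(get_dim_blocks(2, 2, dim_padding=1))  # 3 sliding windows along D in the demo below
print(get_dim_blocks(4, 2, dim_padding=1))  # 5 sliding windows along H (and W)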
Note that if you have overlapping patches and you combine them, the overlapping elements will be summed. If you would like to recover the original input, there is a way: pass torch.ones_like(patches_tensor) through combine_patches_3d with the same arguments and divide the combined output by it element-wise, as shown at the bottom of the example.
The 3D methods chain fold and unfold one axis at a time: during extraction we first unfold the D dimension and leave H and W untouched by setting the kernel to 1, padding to 0, stride to 1 and dilation to 1 on that axis. After that we reshape the tensor and unfold over the H and W dimensions. Combining happens in reverse, folding first over H and W, then over D.
import torch


def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(padding, int):
padding = (padding, padding, padding, padding, padding, padding)
if isinstance(stride, int):
stride = (stride, stride, stride)
channels = x.shape[1]
x = torch.nn.functional.pad(x, padding)
# (B, C, D, H, W)
x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
# (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
# (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
return x
def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
return dim_out
channels = x.shape[1]
d_dim_in = x.shape[2]
h_dim_in = x.shape[3]
w_dim_in = x.shape[4]
d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
# print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
# (B, C, D, H, W)
x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)
# (B, C, D, H * W)
x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
# (B, C * kernel_size[0], d_dim_out * H * W)
x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
# (B, C * kernel_size[0] * d_dim_out, H, W)
x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
# (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out, w_dim_out)
x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
# (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
# (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
# (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
return x
def combine_patches_3d(x, output_shape, kernel_size, padding=0, stride=1, dilation=1):
if isinstance(kernel_size, int):
kernel_size = (kernel_size, kernel_size, kernel_size)
if isinstance(padding, int):
padding = (padding, padding, padding)
if isinstance(stride, int):
stride = (stride, stride, stride)
if isinstance(dilation, int):
dilation = (dilation, dilation, dilation)
def get_dim_blocks(dim_in, dim_kernel_size, dim_padding = 0, dim_stride = 1, dim_dilation = 1):
dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
return dim_out
channels = x.shape[1]
d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
# print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)
x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
# (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
# (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)
x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
# (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
# (B, C * kernel_size[0] * d_dim_in, H, W)
x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
# (B, C * kernel_size[0], d_dim_in * H * W)
x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
# (B, C, D, H * W)
x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
# (B, C, D, H, W)
return x
a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_3d(a, 2, padding=1, stride=1)
bs = extract_patches_3ds(a, 2, padding=1, stride=1)
print(b.shape)
print(b)
c = combine_patches_3d(b, (2,2,2,4,4), kernel_size=2, padding=1, stride=1)
print(c.shape)
print(c)
ones = torch.ones_like(b)
ones = combine_patches_3d(ones, (2,2,2,4,4), kernel_size=2, padding=1, stride=1)  # per-voxel overlap counts
print(torch.all(a==c))
print(c.shape, ones.shape)
d = c / ones  # divide the summed reconstruction by the counts to recover the input
print(d)
print(torch.all(a==d))
Output (3D)
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
torch.Size([150, 2, 2, 2, 2])
tensor([[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 1.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 1., 2.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 2., 3.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 3., 4.]]]],
[[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 4., 0.]]],
[[[ 0., 0.],
[ 0., 0.]],
[[ 0., 1.],
[ 0., 5.]]]],
...,
[[[[124., 0.],
[128., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[ 0., 125.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[125., 126.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[126., 127.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]],
[[[[127., 128.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]],
[[[128., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 0.]]]]])
torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 8., 16., 24., 32.],
[ 40., 48., 56., 64.],
[ 72., 80., 88., 96.],
[ 104., 112., 120., 128.]],
[[ 136., 144., 152., 160.],
[ 168., 176., 184., 192.],
[ 200., 208., 216., 224.],
[ 232., 240., 248., 256.]]],
[[[ 264., 272., 280., 288.],
[ 296., 304., 312., 320.],
[ 328., 336., 344., 352.],
[ 360., 368., 376., 384.]],
[[ 392., 400., 408., 416.],
[ 424., 432., 440., 448.],
[ 456., 464., 472., 480.],
[ 488., 496., 504., 512.]]]],
[[[[ 520., 528., 536., 544.],
[ 552., 560., 568., 576.],
[ 584., 592., 600., 608.],
[ 616., 624., 632., 640.]],
[[ 648., 656., 664., 672.],
[ 680., 688., 696., 704.],
[ 712., 720., 728., 736.],
[ 744., 752., 760., 768.]]],
[[[ 776., 784., 792., 800.],
[ 808., 816., 824., 832.],
[ 840., 848., 856., 864.],
[ 872., 880., 888., 896.]],
[[ 904., 912., 920., 928.],
[ 936., 944., 952., 960.],
[ 968., 976., 984., 992.],
[1000., 1008., 1016., 1024.]]]]])
tensor(False)
torch.Size([2, 2, 2, 4, 4]) torch.Size([2, 2, 2, 4, 4])
tensor([[[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[ 13., 14., 15., 16.]],
[[ 17., 18., 19., 20.],
[ 21., 22., 23., 24.],
[ 25., 26., 27., 28.],
[ 29., 30., 31., 32.]]],
[[[ 33., 34., 35., 36.],
[ 37., 38., 39., 40.],
[ 41., 42., 43., 44.],
[ 45., 46., 47., 48.]],
[[ 49., 50., 51., 52.],
[ 53., 54., 55., 56.],
[ 57., 58., 59., 60.],
[ 61., 62., 63., 64.]]]],
[[[[ 65., 66., 67., 68.],
[ 69., 70., 71., 72.],
[ 73., 74., 75., 76.],
[ 77., 78., 79., 80.]],
[[ 81., 82., 83., 84.],
[ 85., 86., 87., 88.],
[ 89., 90., 91., 92.],
[ 93., 94., 95., 96.]]],
[[[ 97., 98., 99., 100.],
[101., 102., 103., 104.],
[105., 106., 107., 108.],
[109., 110., 111., 112.]],
[[113., 114., 115., 116.],
[117., 118., 119., 120.],
[121., 122., 123., 124.],
[125., 126., 127., 128.]]]]])
tensor(True)
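Applied to the 172x220x156 volume from the question, this could look like the sketch below. It is my own addition on top of the functions above, assuming a random single-channel volume; I picked a stride of 20 (so neighbouring 32x32x32 windows overlap by 12 voxels) and padding on H and W chosen so that (dim + 2*pad - 32) is divisible by the stride, which means no voxels are dropped:
x = torch.randn(1, 1, 172, 220, 156)   # (B, C, D, H, W)
# stride 20 < kernel 32 -> overlapping windows; padding chosen so that
# (172 - 32) % 20 == 0, (220 + 2*6 - 32) % 20 == 0 and (156 + 2*8 - 32) % 20 == 0
patches = extract_patches_3d(x, kernel_size=32, padding=(0, 6, 8), stride=20)
print(patches.shape)                   # torch.Size([704, 1, 32, 32, 32]) -> 8 * 11 * 8 windows
recon = combine_patches_3d(patches, x.shape, kernel_size=32, padding=(0, 6, 8), stride=20)
counts = combine_patches_3d(torch.ones_like(patches), x.shape, kernel_size=32, padding=(0, 6, 8), stride=20)
print(torch.allclose(x, recon / counts))  # True: dividing by the overlap counts restores the input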
Upvotes: 1
Reputation: 3117
You can use unfold (PyTorch docs):
batch_size, n_channels, n_rows, n_cols = 1, 172, 220, 156
x = torch.arange(batch_size*n_channels*n_rows*n_cols).view(batch_size, n_channels, n_rows, n_cols)
kernel_c, kernel_h, kernel_w = 32, 32, 32
step = 32
# Tensor.unfold(dimension, size, step)
windows_unpacked = x.unfold(1, kernel_c, step).unfold(2, kernel_h, step).unfold(3, kernel_w, step)
print(windows_unpacked.shape)
# result: torch.Size([1, 5, 6, 4, 32, 32, 32])
windows = windows_unpacked.permute(1, 2, 3, 0, 4, 5, 6).reshape(-1, kernel_c, kernel_h, kernel_w)
print(windows.shape)
# result: torch.Size([120, 32, 32, 32])
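Note that with step equal to the window size the windows do not overlap, and because 172, 220 and 156 are not multiples of 32, the trailing 12, 28 and 28 elements of the respective axes are silently dropped by Tensor.unfold. A possible variant (my own sketch, not part of the answer above, continuing from the snippet): pad the tensor so that every element is covered and use a step smaller than the window size to get the overlapping windows the question asks for:
# Tensor.unfold drops trailing elements when (dim - size) is not a multiple of step,
# so pad each of the last three dims by 4 at the end before unfolding with step 16.
padded = torch.nn.functional.pad(x, (0, 4, 0, 4, 0, 4))
overlapping = padded.unfold(1, kernel_c, 16).unfold(2, kernel_h, 16).unfold(3, kernel_w, 16)
print(overlapping.shape)
# result: torch.Size([1, 10, 13, 9, 32, 32, 32]) -> neighbouring windows overlap by 16 voxels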
Upvotes: 1
Reputation: 177
Is all your data exactly 172x220x156? If so, it seems like you could just use a for loop and index into the tensor to get 32x32x32 blocks, correct? (Possibly hardcoding a few things.)
However, I'm not able to totally answer the question because it's not clear how you want to combine the results. To be clear, is this your goal?
1) get a 32x32x32 patch from an image
2) do some arbitrary processing on it
3) save that patch to some result at the correct index
4) repeat
If so, how do you plan on combining the overlapping patches? Sum them? Average them?
However, the indexing:
out_tensor = torch.zeros_like(input_tensor)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            patch = input_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32]
            output = your_model(patch)
            out_tensor[i_idx:i_idx + 32, j_idx:j_idx + 32, k_idx:k_idx + 32] = output
This isn't optimized at all, but I imagine the bulk of the computation will be the actual neural network, and there's no way around that, so optimization might be pointless.
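In the overlap regions the slice assignment above simply lets the last patch written win. If you want to average the overlapping outputs instead, a minimal sketch (my own addition, reusing the same start indices and the hypothetical your_model from above) accumulates the outputs together with a per-voxel count:
out_tensor = torch.zeros_like(input_tensor, dtype=torch.float)
counts = torch.zeros_like(input_tensor, dtype=torch.float)
for i_idx in [0, 32, 64, 96, 128, 140]:
    for j_idx in [0, 32, 64, 96, 128, 160, 188]:
        for k_idx in [0, 32, 64, 96, 124]:
            sl = (slice(i_idx, i_idx + 32), slice(j_idx, j_idx + 32), slice(k_idx, k_idx + 32))
            out_tensor[sl] += your_model(input_tensor[sl])  # sum overlapping outputs
            counts[sl] += 1                                 # how many patches covered each voxel
out_tensor /= counts  # average where patches overlap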
Upvotes: 0