Reputation: 419
PyTorch offers the torch.Tensor.unfold operation, which can be chained over arbitrarily many dimensions to extract overlapping patches. How can we reverse the patch extraction so that the patches are combined back into the shape of the input?
The focus is 3D volumetric images with 1 channel (biomedical). Extraction is possible with unfold; how can we combine the patches if they overlap?
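For concreteness, a minimal sketch of the extraction step (the sizes are just illustrative):

import torch

x = torch.randn(1, 1, 8, 8, 8)  # (B, C, D, H, W), 1 channel
k, s = 4, 2                     # window size and stride
# Each unfold call slides a window over one dimension and appends it as a new axis.
patches = x.unfold(2, k, s).unfold(3, k, s).unfold(4, k, s)
print(patches.shape)  # torch.Size([1, 1, 3, 3, 3, 4, 4, 4]): 27 overlapping 4x4x4 patches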
Upvotes: 4
Views: 1887
Reputation: 419
The above solution makes copies in memory, as it keeps the patches contiguous. This leads to memory issues for large volumes with many overlapping voxels. To extract patches without making a copy in memory, we can do the following in PyTorch:
import torch

def get_dim_blocks(dim_in, kernel_size, padding=0, stride=1, dilation=1):
    return (dim_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def extract_patches_3d(x, kernel_size, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    x = x.contiguous()

    channels, depth, height, width = x.shape[-4:]
    d_blocks = get_dim_blocks(depth, kernel_size=kernel_size[0], stride=stride[0], dilation=dilation[0])
    h_blocks = get_dim_blocks(height, kernel_size=kernel_size[1], stride=stride[1], dilation=dilation[1])
    w_blocks = get_dim_blocks(width, kernel_size=kernel_size[2], stride=stride[2], dilation=dilation[2])
    # One stride entry per output dimension: how far to move in the underlying
    # storage to reach the next block, or the next voxel within a block.
    shape = (channels, d_blocks, h_blocks, w_blocks, kernel_size[0], kernel_size[1], kernel_size[2])
    strides = (width*height*depth,
               stride[0]*width*height,
               stride[1]*width,
               stride[2],
               dilation[0]*width*height,
               dilation[1]*width,
               dilation[2])

    x = x.as_strided(shape, strides)
    x = x.permute(1, 2, 3, 0, 4, 5, 6)
    return x
The method expects a tensor of shape `(B, C, D, H, W)`. The method is based on this and this answer (in NumPy), which explain in more detail what memory strides do. The output will be non-contiguous, and the first 3 dimensions will be the number of blocks or sliding windows in the D, H and W dimensions. Combining them into 1 dimension is not possible, as this would require a copy to contiguous memory.
Test with stride
a = torch.arange(81, dtype=torch.float32).view(1,3,3,3,3)
print(a)
b = extract_patches_3d(a, kernel_size=2, stride=2)
print(b.shape)
print(b.storage())
print(a.data_ptr() == b.data_ptr())
print(b)
Output
tensor([[[[[ 0., 1., 2.],
[ 3., 4., 5.],
[ 6., 7., 8.]],
[[ 9., 10., 11.],
[12., 13., 14.],
[15., 16., 17.]],
[[18., 19., 20.],
[21., 22., 23.],
[24., 25., 26.]]],
[[[27., 28., 29.],
[30., 31., 32.],
[33., 34., 35.]],
[[36., 37., 38.],
[39., 40., 41.],
[42., 43., 44.]],
[[45., 46., 47.],
[48., 49., 50.],
[51., 52., 53.]]],
[[[54., 55., 56.],
[57., 58., 59.],
[60., 61., 62.]],
[[63., 64., 65.],
[66., 67., 68.],
[69., 70., 71.]],
[[72., 73., 74.],
[75., 76., 77.],
[78., 79., 80.]]]]])
torch.Size([1, 1, 1, 3, 2, 2, 2])
0.0
1.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
10.0
11.0
12.0
13.0
14.0
15.0
16.0
17.0
18.0
19.0
20.0
21.0
22.0
23.0
24.0
25.0
26.0
27.0
28.0
29.0
30.0
31.0
32.0
33.0
34.0
35.0
36.0
37.0
38.0
39.0
40.0
41.0
42.0
43.0
44.0
45.0
46.0
47.0
48.0
49.0
50.0
51.0
52.0
53.0
54.0
55.0
56.0
57.0
58.0
59.0
60.0
61.0
62.0
63.0
64.0
65.0
66.0
67.0
68.0
69.0
70.0
71.0
72.0
73.0
74.0
75.0
76.0
77.0
78.0
79.0
80.0
[torch.FloatStorage of size 81]
True
tensor([[[[[[[ 0., 1.],
[ 3., 4.]],
[[ 9., 10.],
[12., 13.]]],
[[[27., 28.],
[30., 31.]],
[[36., 37.],
[39., 40.]]],
[[[54., 55.],
[57., 58.]],
[[63., 64.],
[66., 67.]]]]]]])
Reversing with summation of overlapping voxels using memory strides is not possible if the tensor is contiguous (as it would be after processing in a neural network). However, you can manually sum the patches as explained above, or with slicing as explained here.
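For illustration, a minimal sketch of the manual summation (the helper name combine_patches_3d_sum is my own, and it assumes the patch layout (d_blocks, h_blocks, w_blocks, C, k0, k1, k2) returned by extract_patches_3d above):

def combine_patches_3d_sum(patches, out_shape, kernel_size, stride):
    # patches: (d_blocks, h_blocks, w_blocks, C, k0, k1, k2); out_shape: (C, D, H, W)
    out = torch.zeros(out_shape, dtype=patches.dtype)
    d_blocks, h_blocks, w_blocks = patches.shape[:3]
    k0, k1, k2 = kernel_size
    s0, s1, s2 = stride
    for i in range(d_blocks):
        for j in range(h_blocks):
            for k in range(w_blocks):
                # Overlapping voxels are summed, matching what fold would do.
                out[:, i*s0:i*s0+k0, j*s1:j*s1+k1, k*s2:k*s2+k2] += patches[i, j, k]
    return out

# e.g. with the test above: combine_patches_3d_sum(b, (3, 3, 3, 3), (2, 2, 2), (2, 2, 2))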
Upvotes: 2
Reputation: 419
To extract (overlapping) patches and to reconstruct the input shape, we can use the torch.nn.functional.unfold and the inverse operation torch.nn.functional.fold. These methods only process 4D tensors (batched 2D images); however, you can apply them to one dimension at a time.
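As a quick illustration of the one-dimension-at-a-time idea (a minimal sketch with arbitrary sizes): setting the kernel size to 1 in one slot makes unfold slide along the other dimension only.

import torch

x = torch.arange(12, dtype=torch.float).view(1, 1, 3, 4)  # (B, C, H, W)
# Kernel (2, 1): slides only along H; W is left untouched (kernel 1, stride 1 there).
u = torch.nn.functional.unfold(x, kernel_size=(2, 1), stride=(1, 1))
print(u.shape)  # torch.Size([1, 2, 8]) = (B, C * 2, h_dim_out * W)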
A few notes:

- This way requires the fold/unfold methods from PyTorch; unfortunately, I have yet to find a similar method in the TF API.
- We start with 2D, then 3D, then 4D to show the incremental differences; you can extend to arbitrarily many dimensions (probably write a loop instead of hardcoding each dimension like I did).
- We can extract patches in 2 ways with the same output. The methods are called extract_patches_Xd and extract_patches_Xds, where X is the number of dimensions. The latter uses torch.Tensor.unfold() and has fewer lines of code (the output is the same, except it cannot use dilation).
- The methods extract_patches_Xd and combine_patches_Xd are inverse methods, and the combiner reverses the steps from the extractor step by step.
- The lines of code are followed by a comment stating the dimensionality, such as (B, C, T, D, H, W). The following abbreviations are used:
  - B: Batch size
  - C: Channels
  - T: Time dimension
  - D: Depth dimension
  - H: Height dimension
  - W: Width dimension
  - x_dim_in: In the extraction method, the number of input pixels in dimension x. In the combining method, the number of sliding windows in dimension x.
  - x_dim_out: In the extraction method, the number of sliding windows in dimension x. In the combining method, the number of output pixels in dimension x.

I have a public notebook to try out the code.
I have tried out basic 2D, 3D and 4D tensors as shown below. However, my code is not infallible and I appreciate feedback when tested on other inputs.
The get_dim_blocks() method is the function given on the PyTorch docs website to compute the output shape of a convolutional layer.
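As a quick sanity check (a minimal sketch with arbitrary hyperparameters), it agrees with the spatial output size of an actual convolution:

import torch

def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
    return (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1

conv = torch.nn.Conv2d(1, 1, kernel_size=2, padding=1, stride=2, dilation=1)
out = conv(torch.zeros(1, 1, 4, 4))
print(out.shape[2], get_dim_blocks(4, 2, 1, 2, 1))  # 3 3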
Note that if you have overlapping patches and you combine them, the overlapping elements will be summed. If you would like to get the initial input back again, you can divide the result by the overlap count, obtained by running torch.ones_like(patches_tensor) through the same combine method.
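A minimal sketch of that trick, using the 2D methods defined below:

a = torch.arange(16, dtype=torch.float).view(1, 1, 4, 4)
patches = extract_patches_2d(a, 2, padding=1, stride=1)  # overlapping patches
summed = combine_patches_2d(patches, 2, a.shape, padding=1, stride=1)
# Combining all-ones patches counts how often each pixel was covered by a patch...
counts = combine_patches_2d(torch.ones_like(patches), 2, a.shape, padding=1, stride=1)
# ...so dividing the summed result by the counts recovers the original input.
print(torch.allclose(a, summed / counts))  # True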
First (2D):
The torch.nn.functional.fold and torch.nn.functional.unfold methods can be used directly.
import torch
def extract_patches_2ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1])
    # (B, C, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1])
    # (B * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1])
    return x
def extract_patches_2d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    h_dim_in = x.shape[2]
    w_dim_in = x.shape[3]
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])

    # (B, C, H, W)
    x = torch.nn.functional.unfold(x, kernel_size, padding=padding, stride=stride, dilation=dilation)
    # (B, C * kernel_size[0] * kernel_size[1], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], kernel_size[1], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], kernel_size[1], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 4, 5, 2, 3)
    # (B, C, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1])
    # (B * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1])
    return x
def combine_patches_2d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    h_dim_out, w_dim_out = output_shape[2:]
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])

    # (B * h_dim_in * w_dim_in, C, kernel_size[0], kernel_size[1])
    x = x.view(-1, channels, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1])
    # (B, C, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1])
    x = x.permute(0, 1, 4, 5, 2, 3)
    # (B, C, kernel_size[0], kernel_size[1], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * kernel_size[1], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * kernel_size[1], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, (h_dim_out, w_dim_out), kernel_size=(kernel_size[0], kernel_size[1]), padding=padding, stride=stride, dilation=dilation)
    # (B, C, H, W)
    return x
a = torch.arange(1, 65, dtype=torch.float).view(2,2,4,4)
print(a.shape)
print(a)
b = extract_patches_2d(a, 2, padding=1, stride=2, dilation=1)
# b = extract_patches_2ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_2d(b, 2, (2,2,4,4), padding=1, stride=2, dilation=1)
print(c.shape)
print(c)
print(torch.all(a==c))
Output (2D)
torch.Size([2, 2, 4, 4])
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]],
[[17., 18., 19., 20.],
[21., 22., 23., 24.],
[25., 26., 27., 28.],
[29., 30., 31., 32.]]],
[[[33., 34., 35., 36.],
[37., 38., 39., 40.],
[41., 42., 43., 44.],
[45., 46., 47., 48.]],
[[49., 50., 51., 52.],
[53., 54., 55., 56.],
[57., 58., 59., 60.],
[61., 62., 63., 64.]]]])
torch.Size([18, 2, 2, 2])
tensor([[[[ 0., 0.],
[ 0., 1.]],
[[ 0., 0.],
[ 2., 3.]]],
[[[ 0., 0.],
[ 4., 0.]],
[[ 0., 5.],
[ 0., 9.]]],
[[[ 6., 7.],
[10., 11.]],
[[ 8., 0.],
[12., 0.]]],
[[[ 0., 13.],
[ 0., 0.]],
[[14., 15.],
[ 0., 0.]]],
[[[16., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 17.]]],
[[[ 0., 0.],
[18., 19.]],
[[ 0., 0.],
[20., 0.]]],
[[[ 0., 21.],
[ 0., 25.]],
[[22., 23.],
[26., 27.]]],
[[[24., 0.],
[28., 0.]],
[[ 0., 29.],
[ 0., 0.]]],
[[[30., 31.],
[ 0., 0.]],
[[32., 0.],
[ 0., 0.]]],
[[[ 0., 0.],
[ 0., 33.]],
[[ 0., 0.],
[34., 35.]]],
[[[ 0., 0.],
[36., 0.]],
[[ 0., 37.],
[ 0., 41.]]],
[[[38., 39.],
[42., 43.]],
[[40., 0.],
[44., 0.]]],
[[[ 0., 45.],
[ 0., 0.]],
[[46., 47.],
[ 0., 0.]]],
[[[48., 0.],
[ 0., 0.]],
[[ 0., 0.],
[ 0., 49.]]],
[[[ 0., 0.],
[50., 51.]],
[[ 0., 0.],
[52., 0.]]],
[[[ 0., 53.],
[ 0., 57.]],
[[54., 55.],
[58., 59.]]],
[[[56., 0.],
[60., 0.]],
[[ 0., 61.],
[ 0., 0.]]],
[[[62., 63.],
[ 0., 0.]],
[[64., 0.],
[ 0., 0.]]]])
torch.Size([2, 2, 4, 4])
tensor([[[[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]],
[[17., 18., 19., 20.],
[21., 22., 23., 24.],
[25., 26., 27., 28.],
[29., 30., 31., 32.]]],
[[[33., 34., 35., 36.],
[37., 38., 39., 40.],
[41., 42., 43., 44.],
[45., 46., 47., 48.]],
[[49., 50., 51., 52.],
[53., 54., 55., 56.],
[57., 58., 59., 60.],
[61., 62., 63., 64.]]]])
tensor(True)
Second (3D):
Now it becomes interesting: we need 2 unfold/fold passes. For extraction, we first unfold over the D dimension and leave H and W untouched by setting the kernel to 1, padding to 0, stride to 1 and dilation to 1 in the other slot. After we view (reshape) the tensor, we unfold over the H and W dimensions. The combining happens in reverse, folding first over H and W, then over D.
def extract_patches_3ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2])
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
def extract_patches_3d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_in = x.shape[2]
    h_dim_in = x.shape[3]
    w_dim_in = x.shape[4]
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, D, H, W)
    x = x.view(-1, channels, d_dim_in, h_dim_in * w_dim_in)
    # (B, C, D, H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], d_dim_out * H * W)
    x = x.view(-1, channels * kernel_size[0] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * d_dim_out, H, W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_out * kernel_size[1] * kernel_size[2], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], d_dim_out, kernel_size[1], kernel_size[2], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 3, 6, 7, 2, 4, 5)
    # (B, C, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2])
    return x
def combine_patches_3d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    # print(d_dim_in, h_dim_in, w_dim_in, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    # (B, C, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2])
    x = x.permute(0, 1, 5, 2, 6, 7, 3, 4)
    # (B, C, kernel_size[0], d_dim_in, kernel_size[1], kernel_size[2], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * d_dim_in * kernel_size[1] * kernel_size[2], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[1], kernel_size[2]), padding=(padding[1], padding[2]), stride=(stride[1], stride[2]), dilation=(dilation[1], dilation[2]))
    # (B, C * kernel_size[0] * d_dim_in, H, W)
    x = x.view(-1, channels * kernel_size[0], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], d_dim_in * H * W)
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, D, H * W)
    x = x.view(-1, channels, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, D, H, W)
    return x
a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,4,4)
print(a.shape)
print(a)
# b = extract_patches_3d(a, 2, padding=1, stride=2)
b = extract_patches_3ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_3d(b, 2, (2,2,2,4,4), padding=1, stride=2)
print(c.shape)
print(c)
print(torch.all(a==c))
Output (3D)
(I had to limit the characters; please look at the notebook.)
Third (4D):
We add a time dimension to the 3D volume. For extraction, we start by unfolding over just the T dimension, leaving D, H and W alone, similarly to the 3D version. Then we unfold over D, leaving H and W, and finally over H and W together. The combining again happens in reverse. Hopefully by now you notice the pattern, so you can add arbitrarily many dimensions and unfold them one by one.
def extract_patches_4ds(x, kernel_size, padding=0, stride=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding, padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)

    channels = x.shape[1]

    x = torch.nn.functional.pad(x, padding)
    # (B, C, T, D, H, W)
    x = x.unfold(2, kernel_size[0], stride[0]).unfold(3, kernel_size[1], stride[1]).unfold(4, kernel_size[2], stride[2]).unfold(5, kernel_size[3], stride[3])
    # (B, C, t_dim_out, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B * t_dim_out * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    return x
def extract_patches_4d(x, kernel_size, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    t_dim_in = x.shape[2]
    d_dim_in = x.shape[3]
    h_dim_in = x.shape[4]
    w_dim_in = x.shape[5]
    t_dim_out = get_dim_blocks(t_dim_in, kernel_size[0], padding[0], stride[0], dilation[0])
    d_dim_out = get_dim_blocks(d_dim_in, kernel_size[1], padding[1], stride[1], dilation[1])
    h_dim_out = get_dim_blocks(h_dim_in, kernel_size[2], padding[2], stride[2], dilation[2])
    w_dim_out = get_dim_blocks(w_dim_in, kernel_size[3], padding[3], stride[3], dilation[3])
    # print(t_dim_in, d_dim_in, h_dim_in, w_dim_in, t_dim_out, d_dim_out, h_dim_out, w_dim_out)

    # (B, C, T, D, H, W)
    x = x.view(-1, channels, t_dim_in, d_dim_in * h_dim_in * w_dim_in)
    # (B, C, T, D * H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C * kernel_size[0], t_dim_out * D * H * W)
    x = x.view(-1, channels * kernel_size[0] * t_dim_out, d_dim_in, h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * t_dim_out, D, H * W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[1], 1), padding=(padding[1], 0), stride=(stride[1], 1), dilation=(dilation[1], 1))
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1], d_dim_out * H * W)
    x = x.view(-1, channels * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out, h_dim_in, w_dim_in)
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out, H, W)
    x = torch.nn.functional.unfold(x, kernel_size=(kernel_size[2], kernel_size[3]), padding=(padding[2], padding[3]), stride=(stride[2], stride[3]), dilation=(dilation[2], dilation[3]))
    # (B, C * kernel_size[0] * t_dim_out * kernel_size[1] * d_dim_out * kernel_size[2] * kernel_size[3], h_dim_out * w_dim_out)
    x = x.view(-1, channels, kernel_size[0], t_dim_out, kernel_size[1], d_dim_out, kernel_size[2], kernel_size[3], h_dim_out, w_dim_out)
    # (B, C, kernel_size[0], t_dim_out, kernel_size[1], d_dim_out, kernel_size[2], kernel_size[3], h_dim_out, w_dim_out)
    x = x.permute(0, 1, 3, 5, 8, 9, 2, 4, 6, 7)
    # (B, C, t_dim_out, d_dim_out, h_dim_out, w_dim_out, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.contiguous().view(-1, channels, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B * t_dim_out * d_dim_out * h_dim_out * w_dim_out, C, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    return x
def combine_patches_4d(x, kernel_size, output_shape, padding=0, stride=1, dilation=1):
    if isinstance(kernel_size, int):
        kernel_size = (kernel_size, kernel_size, kernel_size, kernel_size)
    if isinstance(padding, int):
        padding = (padding, padding, padding, padding)
    if isinstance(stride, int):
        stride = (stride, stride, stride, stride)
    if isinstance(dilation, int):
        dilation = (dilation, dilation, dilation, dilation)

    def get_dim_blocks(dim_in, dim_kernel_size, dim_padding=0, dim_stride=1, dim_dilation=1):
        dim_out = (dim_in + 2 * dim_padding - dim_dilation * (dim_kernel_size - 1) - 1) // dim_stride + 1
        return dim_out

    channels = x.shape[1]
    t_dim_out, d_dim_out, h_dim_out, w_dim_out = output_shape[2:]
    t_dim_in = get_dim_blocks(t_dim_out, kernel_size[0], padding[0], stride[0], dilation[0])
    d_dim_in = get_dim_blocks(d_dim_out, kernel_size[1], padding[1], stride[1], dilation[1])
    h_dim_in = get_dim_blocks(h_dim_out, kernel_size[2], padding[2], stride[2], dilation[2])
    w_dim_in = get_dim_blocks(w_dim_out, kernel_size[3], padding[3], stride[3], dilation[3])
    # print(t_dim_in, d_dim_in, h_dim_in, w_dim_in, t_dim_out, d_dim_out, h_dim_out, w_dim_out)

    x = x.view(-1, channels, t_dim_in, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    # (B, C, t_dim_in, d_dim_in, h_dim_in, w_dim_in, kernel_size[0], kernel_size[1], kernel_size[2], kernel_size[3])
    x = x.permute(0, 1, 6, 2, 7, 3, 8, 9, 4, 5)
    # (B, C, kernel_size[0], t_dim_in, kernel_size[1], d_dim_in, kernel_size[2], kernel_size[3], h_dim_in, w_dim_in)
    x = x.contiguous().view(-1, channels * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in * kernel_size[2] * kernel_size[3], h_dim_in * w_dim_in)
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in * kernel_size[2] * kernel_size[3], h_dim_in * w_dim_in)
    x = torch.nn.functional.fold(x, output_size=(h_dim_out, w_dim_out), kernel_size=(kernel_size[2], kernel_size[3]), padding=(padding[2], padding[3]), stride=(stride[2], stride[3]), dilation=(dilation[2], dilation[3]))
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1] * d_dim_in, H, W)
    x = x.view(-1, channels * kernel_size[0] * t_dim_in * kernel_size[1], d_dim_in * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0] * t_dim_in * kernel_size[1], d_dim_in * H * W)
    x = torch.nn.functional.fold(x, output_size=(d_dim_out, h_dim_out * w_dim_out), kernel_size=(kernel_size[1], 1), padding=(padding[1], 0), stride=(stride[1], 1), dilation=(dilation[1], 1))
    # (B, C * kernel_size[0] * t_dim_in, D, H * W)
    x = x.view(-1, channels * kernel_size[0], t_dim_in * d_dim_out * h_dim_out * w_dim_out)
    # (B, C * kernel_size[0], t_dim_in * D * H * W)
    x = torch.nn.functional.fold(x, output_size=(t_dim_out, d_dim_out * h_dim_out * w_dim_out), kernel_size=(kernel_size[0], 1), padding=(padding[0], 0), stride=(stride[0], 1), dilation=(dilation[0], 1))
    # (B, C, T, D * H * W)
    x = x.view(-1, channels, t_dim_out, d_dim_out, h_dim_out, w_dim_out)
    # (B, C, T, D, H, W)
    return x
a = torch.arange(1, 129, dtype=torch.float).view(2,2,2,2,4,2)
print(a.shape)
print(a)
# b = extract_patches_4d(a, 2, padding=1, stride=2)
b = extract_patches_4ds(a, 2, padding=1, stride=2)
print(b.shape)
print(b)
c = combine_patches_4d(b, 2, (2,2,2,2,4,2), padding=1, stride=2)
print(c.shape)
print(c)
print(torch.all(a==c))
Output (4D)
(I had to limit the characters; please look at the notebook.)
Upvotes: 6