Reputation: 489
How do I perform sum pooling in PyTorch? Specifically, given an input of shape (N, C, W_in, H_in), I want an output of shape (N, C, W_out, H_out) using a particular kernel_size and stride, just like nn.MaxPool2d.
Upvotes: 4
Views: 8406
Reputation: 56
In https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html#torch.nn.AvgPool2d, find divisor_override.
Set divisor_override=1 and you'll get a sum pool:
import torch
# avg_pool2d is only implemented for floating-point dtypes, so use a float tensor
input = torch.tensor([[[1., 2., 3.], [3., 2., 1.], [3., 4., 5.]]])
# divisor_override=1 divides each window sum by 1 instead of by the window area
sumpool = torch.nn.AvgPool2d(2, stride=1, divisor_override=1)
sumpool(input)
you'll get
tensor([[[ 8.,  8.],
         [12., 12.]]])
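The same module applies unchanged to a batched 4-D input matching the question's (N, C, ...) shape; a minimal sketch (shapes chosen arbitrarily):
import torch
x = torch.rand(2, 3, 8, 8)  # (N, C, H, W)
sumpool = torch.nn.AvgPool2d(2, stride=2, divisor_override=1)
y = sumpool(x)  # shape (2, 3, 4, 4); each entry is the sum over one 2x2 window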
Upvotes: 3
Reputation: 2105
To expand on benjaminplanche's answer:
I needed sum pooling as well, and it doesn't seem to exist directly, but it is equivalent to running conv2d with a weight parameter made of ones. I thought it would be faster to run AvgPool2d and multiply by the kernel-size product. Turns out, not exactly.
Bottom line up front:
Use torch.nn.functional.avg_pool2d (or its related functions) and multiply by the kernel-size product, as sketched below.
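A minimal sketch of that approach (the helper name sum_pool2d is mine, not part of PyTorch):
import torch
import torch.nn.functional as F

def sum_pool2d(x, kernel_size):
    # F.avg_pool2d defaults to stride = kernel_size (non-overlapping windows);
    # multiplying each window mean by the window area recovers the window sum
    kh, kw = kernel_size
    return F.avg_pool2d(x, kernel_size) * (kh * kw)

x = torch.rand(1, 1, 1000, 1000)
out = sum_pool2d(x, (10, 10))  # shape (1, 1, 100, 100)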
Testing in Jupyter I find (with these imports in a setup cell):
import torch
import torch.nn as nn
import torch.nn.functional as F
(Overhead)
%%timeit
x = torch.rand([1,1,1000,1000])
>>> 3.49 ms ± 4.72 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
%%timeit
_=F.avg_pool2d(torch.rand([1,1,1000,1000]), [10,10])*10*10
>>> 4.99 ms ± 74.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(So 1.50 ms ± 79.0 µs.) (I found the *10*10
only adds around 20 µs to the total time.)
avePool = nn.AvgPool2d([10, 10], 1, 0)
%%timeit
_=avePool(torch.rand([1,1,1000,1000]))*10*10
>>> 80.9 ms ± 1.57 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
(So 77.4 ms ± 1.58 ms)
y = torch.ones([1,1,10,10])
%%timeit
_=F.conv2d(torch.rand([1,1,1000,1000]), y)
>>> 14.4 ms ± 421 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(So 10.9 ms ± 426 µs)
sumPool = nn.Conv2d(1, 1, 10, stride=1, padding=0, dilation=1, groups=1, bias=False)
sumPool.weight = torch.nn.Parameter(y)  # reuse the all-ones kernel from above
%%timeit
_=sumPool(torch.rand([1,1,1000,1000]))
>>> 7.24 ms ± 63.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
(So 3.75 ms ± 68.3 µs)
And as a sanity check (x is defined in a fresh cell here, since variables created inside %%timeit cells do not persist):
x = torch.rand([1,1,1000,1000])
abs_err = torch.max(torch.abs(avePool(x)*10*10 - sumPool(x)))
magnitude = torch.max(torch.max(avePool(x)*10*10, torch.max(sumPool(x))))
relative_err = abs_err/magnitude
abs_err.item(), magnitude.item(), relative_err.item()
>>> (3.814697265625e-06, 62.89910125732422, 6.064788493631568e-08)
That's probably a reasonable rounding-related error.
I do not know for certain why the functional version is faster than making a dedicated kernel, though note that F.avg_pool2d with kernel [10,10] defaults to a stride equal to the kernel size, while avePool and sumPool above were constructed with stride 1 and therefore compute roughly 100x as many output windows, which likely explains much of the gap. If you do want a dedicated kernel, prefer the Conv2d version, and make the weights untrainable with sumPool.weight.requires_grad = False
or with torch.no_grad():
during creation of the kernel parameters (see the sketch below). These results may change with kernel size, so test for your own application if you need to speed up this part. Let me know if I missed something...
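For reference, one way to build such a fixed kernel with the weights frozen from the start (a sketch, not benchmarked here):
import torch
import torch.nn as nn

# Conv2d with an all-ones 10x10 kernel acts as a stride-1 sum pool
sumPool = nn.Conv2d(1, 1, 10, stride=1, padding=0, bias=False)
with torch.no_grad():
    sumPool.weight.fill_(1.0)  # set every weight to one
sumPool.weight.requires_grad = False  # keep the kernel untrainable

x = torch.rand(1, 1, 1000, 1000)
out = sumPool(x)  # shape (1, 1, 991, 991)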
Upvotes: 1
Reputation: 15119
You could use torch.nn.AvgPool1d
(or torch.nn.AvgPool2d
, torch.nn.AvgPool3d
), which perform mean pooling, and mean pooling is proportional to sum pooling. If you really want the summed values, multiply the averaged output by the pooling surface (the number of elements in the kernel).
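A minimal sketch of that multiplication (kernel size and shapes are arbitrary here):
import torch
import torch.nn as nn

pool = nn.AvgPool2d(kernel_size=2, stride=2)
x = torch.rand(1, 3, 8, 8)
summed = pool(x) * (2 * 2)  # mean over each 2x2 window times its area = window sum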
Upvotes: 8