I am studying the documentation at https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html.
In the Parameters section, it states:
return_indices – if True, will return the max indices along with the outputs. Useful for torch.nn.MaxUnpool2d later
Could someone explain what "max indices" means here? I believe they are the indices corresponding to the maximal values. If the maximal value is unique, does that mean only one index will be returned?
I assume you already know how max pooling works, so let's print some results to get more insight.
```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, return_indices=True)

# 4x4 input with exactly one 1 in each non-overlapping 2x2 window
input = torch.zeros(1, 1, 4, 4)
input[..., 0, 1] = input[..., 1, 3] = input[..., 2, 2] = input[..., 3, 0] = 1.
print(input)
# tensor([[[[0., 1., 0., 0.],
#           [0., 0., 0., 1.],
#           [0., 0., 1., 0.],
#           [1., 0., 0., 0.]]]])

output, indices = pool(input)
print(output)
# tensor([[[[1., 1.],
#           [1., 1.]]]])
print(indices)
# tensor([[[[ 1,  7],
#           [12, 10]]]])
```
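A quick way to check what indices contains is to flatten each input plane and gather from it at those positions. This minimal sketch re-creates the same input (named x here only to avoid shadowing Python's built-in input), gathers the flattened values at indices, and converts the flat indices back to (row, col) coordinates using the plane width.

```python
import torch
import torch.nn as nn

# Same setup as above: one 1 in each non-overlapping 2x2 window of a 4x4 plane.
pool = nn.MaxPool2d(kernel_size=2, return_indices=True)
x = torch.zeros(1, 1, 4, 4)
x[..., 0, 1] = x[..., 1, 3] = x[..., 2, 2] = x[..., 3, 0] = 1.

output, indices = pool(x)

# One index per pooling window: indices has the same shape as output.
print(indices.shape)  # torch.Size([1, 1, 2, 2])

# Each index points into the flattened H*W plane of the input.
flat = x.flatten(start_dim=2)  # shape (1, 1, 16)
gathered = flat.gather(2, indices.flatten(start_dim=2)).view_as(output)
print(torch.equal(gathered, output))  # True

# Recover (row, col) in the input: index = row * W + col.
W = x.shape[-1]
rows = torch.div(indices, W, rounding_mode="floor")
cols = indices % W
print(rows.flatten().tolist())  # [0, 1, 3, 2]
print(cols.flatten().tolist())  # [1, 3, 0, 2]
```

The recovered pairs (0, 1), (1, 3), (3, 0) and (2, 2) are exactly the positions where the 1s were placed.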
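If you flatten each 4×4 input plane into a 1-D array of 16 elements, you can see that indices contains the flat position (row * 4 + col here) of each 1, i.e. of the maximum of each 2×2 window of MaxPool2d. So there is one index per pooling window, not a single index for the whole tensor, and indices has the same shape as output. As the documentation of torch.nn.MaxUnpool2d explains, indices is a required input to that module: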
MaxUnpool2d takes in as input the output of MaxPool2d including the indices of the maximal values and computes a partial inverse in which all non-maximal values are set to zero.
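To see that partial inverse in action, here is a minimal sketch that pools a 4×4 tensor and passes both output and indices to nn.MaxUnpool2d. The input values 0 to 15 are an arbitrary choice for illustration: with them, the maximum of every 2×2 window is its bottom-right entry, so it is easy to predict which positions survive.

```python
import torch
import torch.nn as nn

pool = nn.MaxPool2d(kernel_size=2, return_indices=True)
unpool = nn.MaxUnpool2d(kernel_size=2)

# 4x4 plane holding 0..15; each 2x2 window's maximum is its bottom-right entry.
x = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)

output, indices = pool(x)
print(output.flatten().tolist())   # [5.0, 7.0, 13.0, 15.0]
print(indices.flatten().tolist())  # [5, 7, 13, 15]

# MaxUnpool2d writes each pooled value back at its recorded index
# and fills every other position with zero.
restored = unpool(output, indices)
print(restored.shape)  # torch.Size([1, 1, 4, 4])
print(restored[0, 0].tolist())
# [[0.0, 0.0, 0.0, 0.0], [0.0, 5.0, 0.0, 7.0], [0.0, 0.0, 0.0, 0.0], [0.0, 13.0, 0.0, 15.0]]
```

Without indices, MaxUnpool2d would have no way of knowing where inside each window the maximum came from, which is why return_indices=True is needed on the pooling side.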