Rnj

Reputation: 1189

Understanding torch.nn.Flatten

I understand that flattening removes all of the dimensions except for one. For example, I understand torch.flatten():

>>> import torch
>>> t = torch.ones(4, 3)
>>> t
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]])

>>> torch.flatten(t)
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])

However, I don't understand nn.Flatten, and in particular I don't understand this snippet from the docs:

>>> import torch
>>> from torch import nn
>>> input = torch.randn(32, 1, 5, 5)
>>> m = nn.Sequential(
...     nn.Conv2d(1, 32, 5, 1, 1),
...     nn.Flatten()
... )
>>> output = m(input)
>>> output.size()
torch.Size([32, 288])

I expected the output to have size [160], because 32*5=160.

Q1. So why is the output size [32, 288]?

Q2. I also don't understand the shape information given in the docs:

[Screenshot: the Shape section of the nn.Flatten docs]

Q3. And also the meaning of the parameters:

[Screenshot: the Parameters section of the nn.Flatten docs]

Upvotes: 7

Views: 10804

Answers (1)

GoodDeeds

Reputation: 8527

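This comes from a difference in default behaviour: torch.flatten flattens all dimensions by default, while torch.nn.Flatten flattens all dimensions starting from the second dimension (index 1) by default. In the docs example, the Conv2d layer (kernel size 5, stride 1, padding 1) first turns the (32, 1, 5, 5) input into a tensor of shape (32, 32, 3, 3), since (5 + 2 x 1 - 5) / 1 + 1 = 3. nn.Flatten then keeps the batch dimension of 32 and merges the rest into 32 x 3 x 3 = 288, which is where [32, 288] comes from.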
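A minimal sketch that reruns the docs example and prints the intermediate shapes, so you can see where 288 comes from and how torch.flatten would differ. The seed and variable names are only for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)  # only so the random input is reproducible

# Same input and layer as the docs snippet in the question.
x = torch.randn(32, 1, 5, 5)
conv = nn.Conv2d(1, 32, 5, 1, 1)  # in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=1

features = conv(x)
# Spatial size: (5 + 2*1 - 5) // 1 + 1 = 3, so each image becomes 32 channels of 3x3.
print(features.shape)                 # torch.Size([32, 32, 3, 3])

# nn.Flatten keeps dim 0 (the batch) and merges the rest: 32 * 3 * 3 = 288.
print(nn.Flatten()(features).shape)   # torch.Size([32, 288])

# torch.flatten merges everything, batch included: 32 * 32 * 3 * 3 = 9216.
print(torch.flatten(features).shape)  # torch.Size([9216])
```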

You can see this behaviour in the default values of the start_dim and end_dim arguments. start_dim is the first dimension to be flattened (zero-indexed), and end_dim is the last dimension to be flattened; both functions default to end_dim=-1, i.e. the last dimension. So with start_dim=1, the default for torch.nn.Flatten, the first dimension (index 0) is not flattened, whereas with start_dim=0, the default for torch.flatten, it is included.
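To illustrate the two arguments, here is a small sketch on an arbitrary 4D tensor (the shape (2, 3, 4, 5) is just an example). It compares the defaults and then flattens only a middle range of dimensions.

```python
import torch
from torch import nn

t = torch.zeros(2, 3, 4, 5)

# Defaults: torch.flatten uses start_dim=0, nn.Flatten uses start_dim=1; both use end_dim=-1.
print(torch.flatten(t).shape)                          # torch.Size([120])
print(nn.Flatten()(t).shape)                           # torch.Size([2, 60])

# Only dims 1..2 are merged (3 * 4 = 12); dims 0 and 3 are left alone.
print(nn.Flatten(start_dim=1, end_dim=2)(t).shape)     # torch.Size([2, 12, 5])
print(torch.flatten(t, start_dim=1, end_dim=2).shape)  # torch.Size([2, 12, 5])
```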

The reason for this difference is probably that torch.nn.Flatten is intended to be used with torch.nn.Sequential, where a series of operations is typically performed on a batch of inputs and each input is treated independently of the others. For example, if you have a batch of images and call torch.nn.Flatten, the typical use case is to flatten each image separately, not the whole batch.
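As a sketch of that batch use case, assuming a hypothetical batch of four 3-channel 8x8 images: each row of the nn.Flatten output is exactly one image flattened on its own, which is what a following nn.Linear layer expects. The layer sizes here are illustrative only.

```python
import torch
from torch import nn

torch.manual_seed(0)

batch = torch.randn(4, 3, 8, 8)  # hypothetical batch: 4 images, 3 channels, 8x8
flat = nn.Flatten()(batch)       # shape (4, 3 * 8 * 8) = (4, 192)

# Row i of the result is image i flattened by itself,
# so samples in the batch are never mixed together.
for i in range(batch.shape[0]):
    assert torch.equal(flat[i], torch.flatten(batch[i]))

# Typical Sequential usage: flatten per sample, then a linear layer over 192 features.
head = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 10))
logits = head(batch)

print(flat.shape)    # torch.Size([4, 192])
print(logits.shape)  # torch.Size([4, 10])
```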

If you do want to flatten all dimensions using torch.nn.Flatten, you can simply create the object as torch.nn.Flatten(start_dim=0).
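A quick check of that equivalence on a small example tensor:

```python
import torch
from torch import nn

x = torch.arange(24).reshape(2, 3, 4)

# With start_dim=0, nn.Flatten also flattens the first dimension.
all_flat = nn.Flatten(start_dim=0)(x)
print(all_flat.shape)  # torch.Size([24])

# Same result as the functional default.
assert torch.equal(all_flat, torch.flatten(x))
```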

Finally, the shape information in the docs just covers how the shape of the tensor will be affected, illustrating that the first (index 0) dimension is left as it is. So, if you have an input tensor of shape (N, *dims), where *dims is an arbitrary sequence of dimensions, the output tensor will have the shape (N, product of *dims), since all dimensions except the batch dimension are flattened. For example, an input of shape (3,10,10) will have an output of shape (3, 10 x 10) = (3, 100).
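The shape rule for the default case can be written as a tiny helper and checked against nn.Flatten. The helper flatten_output_shape is hypothetical, written only for this illustration, and assumes the input has at least two dimensions.

```python
import math

import torch
from torch import nn


def flatten_output_shape(shape):
    """Hypothetical helper: expected nn.Flatten() output shape for an input of shape (N, *dims)."""
    n, *dims = shape
    return (n, math.prod(dims))


for shape in [(3, 10, 10), (32, 32, 3, 3), (5, 7)]:
    out = nn.Flatten()(torch.zeros(shape))
    assert tuple(out.shape) == flatten_output_shape(shape)
    print(shape, "->", tuple(out.shape))
# (3, 10, 10) -> (3, 100)
# (32, 32, 3, 3) -> (32, 288)
# (5, 7) -> (5, 7)
```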

Upvotes: 7
