What are the differences between torch.flatten() and torch.nn.Flatten()?
Upvotes: 19
You can think of the job of torch.flatten() as simply performing a flattening operation on the tensor, with no strings attached. You give it a tensor, it flattens it and returns the result. That's all there is to it.
In contrast, nn.Flatten() is more sophisticated (i.e., it's a neural net layer). Being object oriented, it inherits from nn.Module, although internally it just calls the plain tensor.flatten() op in its forward() method to flatten the tensor. You can think of it as syntactic sugar over torch.flatten().
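To illustrate the "syntactic sugar" point, here is a minimal sketch of a hand-rolled module whose forward() does what nn.Flatten's forward() does. The class name MyFlatten is hypothetical and only for illustration; it is not part of PyTorch.

```python
import torch
from torch import nn


class MyFlatten(nn.Module):
    """Hypothetical re-implementation for illustration; not PyTorch's own class."""

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Delegate to the plain tensor op, just as nn.Flatten does.
        return x.flatten(self.start_dim, self.end_dim)


x = torch.randn(2, 3, 4, 5)
out_mine = MyFlatten()(x)
out_builtin = nn.Flatten()(x)
print(out_mine.shape)                      # torch.Size([2, 60])
print(torch.equal(out_mine, out_builtin))  # True
```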
Important difference: with their default arguments, torch.flatten() always returns a 1D tensor, provided that the input is at least 1D, whereas nn.Flatten() always returns a 2D tensor, provided that the input is at least 2D (with a 1D tensor as input, it throws an IndexError). This is because nn.Flatten() flattens from start_dim=1 by default, leaving the first (batch) dimension intact.
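A short check of the output shapes described above, using each API with its default arguments (the tensor sizes are arbitrary):

```python
import torch
from torch import nn

flatten_layer = nn.Flatten()  # defaults: start_dim=1, end_dim=-1

x3d = torch.zeros(3, 2, 4)
print(torch.flatten(x3d).shape)  # torch.Size([24])   -> 1D
print(flatten_layer(x3d).shape)  # torch.Size([3, 8]) -> 2D

x1d = torch.zeros(5)
print(torch.flatten(x1d).shape)  # torch.Size([5])

raised = False
try:
    flatten_layer(x1d)  # start_dim=1 does not exist on a 1D tensor
except IndexError as err:
    raised = True
    print("IndexError:", err)
```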
torch.flatten() is a functional API whereas nn.Flatten() is a neural net layer.
torch.flatten() is a function whereas nn.Flatten() is a Python class.
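Because of the above point, nn.Flatten() comes with a lot of methods and attributes inherited from nn.Module (e.g., parameters(), to(), state_dict()), although it has no learnable parameters of its own.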
torch.flatten() can be used in the wild (e.g., for simple tensor ops) whereas nn.Flatten() is expected to be used in an nn.Sequential() block as one of the layers.
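Neither API decides on its own whether autograd tracks the operation: the output of torch.flatten() and of nn.Flatten() is recorded in the computation graph only when the input tensor has its requires_grad flag set to True (and grad mode is enabled). Since nn.Flatten() just calls flatten() in its forward(), both behave identically here.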
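Neither accepts a layer itself (e.g., a Linear or Conv1d module) as input; both operate on tensors. nn.Flatten() is mostly used to flatten the output tensors of such layers, for instance between the last conv layer and the first linear layer, although torch.flatten(x, 1) is equally common for this inside a custom forward().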
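Both torch.flatten() and nn.Flatten() return a view of the input tensor when the input is contiguous, as in the examples below. Thus, any modification to the result also affects the input tensor. (See the code below.) For a non-contiguous input, the result may be a copy instead.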
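The points in the list above can be checked with a short script. It is a sketch under arbitrary assumptions: the toy ConvNet's layer sizes (1 input channel, 4 filters, 28x28 images, 10 outputs) are made up for illustration. It shows that nn.Flatten() is an nn.Module with attributes but no parameters, that it typically sits after a conv layer inside nn.Sequential, that torch.flatten(x, start_dim=1) does the same job on the conv output tensor, and that autograd tracking depends on the input's requires_grad flag rather than on which API is used.

```python
import torch
from torch import nn

# --- function vs. class ---
layer = nn.Flatten()
print(isinstance(layer, nn.Module))    # True: a layer object
print(layer.start_dim, layer.end_dim)  # 1 -1
print(len(list(layer.parameters())))   # 0: nothing to learn

# --- typical use: inside nn.Sequential, after a conv layer ---
model = nn.Sequential(
    nn.Conv2d(1, 4, kernel_size=3),  # (N, 1, 28, 28) -> (N, 4, 26, 26)
    nn.ReLU(),
    nn.Flatten(),                    # (N, 4, 26, 26) -> (N, 2704)
    nn.Linear(4 * 26 * 26, 10),      # (N, 2704) -> (N, 10)
)
images = torch.randn(2, 1, 28, 28)  # fake batch of two single-channel images
logits = model(images)
print(logits.shape)  # torch.Size([2, 10])

# The functional form flattens the conv output tensor just the same,
# e.g. inside a custom forward(); here we reuse the model's own layers.
features = model[1](model[0](images))
manual = model[3](torch.flatten(features, start_dim=1))
print(torch.allclose(manual, logits))  # True

# --- autograd: tracking depends on the input, not on the API used ---
plain = torch.randn(2, 3, 4)                        # requires_grad=False
tracked = torch.randn(2, 3, 4, requires_grad=True)
print(torch.flatten(plain, 1).grad_fn, layer(plain).grad_fn)  # None None
print(torch.flatten(tracked, 1).grad_fn is not None)          # True
print(layer(tracked).grad_fn is not None)                     # True
```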
Code demo:
# assumes: import torch; from torch import nn
# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1) # 3D tensor
Flattening with torch.flatten():
In [113]: t1flat = torch.flatten(t1)
In [114]: t1flat
Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
# modification to the flattened tensor
In [115]: t1flat[-1] = -1
# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]:
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])
Flattening with nn.Flatten():
In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)
# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]:
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])
# modification to the result
In [126]: t3flat[-1, -1] = -1
# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]:
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])
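The demo above uses contiguous inputs, so the results are views. As a sketch of the other case, a transposed tensor is non-contiguous, so flattening it cannot be expressed as a view and the data is copied; the data_ptr() comparison is just one convenient way to check for shared storage.

```python
import torch

t = torch.arange(12).reshape(3, 4)       # contiguous
print(torch.flatten(t).data_ptr() == t.data_ptr())  # True: a view

tt = t.t()                               # shares t's storage but is non-contiguous
print(tt.is_contiguous())                # False

flat = torch.flatten(tt)                 # has to copy the data
print(flat.data_ptr() == t.data_ptr())   # False: fresh storage

flat[0] = -1                             # flat[0] corresponds to t[0, 0]
print(t[0, 0].item())                    # 0: the original is unchanged
```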
Tidbit: torch.flatten() is the precursor to nn.Flatten() and its sibling nn.Unflatten(), since it has existed from the very beginning. Later, there was a legitimate use case for nn.Flatten(), since flattening is a common requirement for almost all ConvNets (just before the softmax or elsewhere), so it was added in PR #22245.
There are also recent proposals to use nn.Flatten() in ResNets for model surgery.
Upvotes: 6
Flattening is available in three forms in PyTorch:
As a tensor method (OOP style), torch.Tensor.flatten, applied directly on a tensor: x.flatten().
As a function (functional form), torch.flatten, applied as: torch.flatten(x).
As a module (an nn.Module layer), nn.Flatten(), generally used in a model definition.
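A minimal sketch showing the three forms side by side; start_dim=1 is passed explicitly to the method and the function so that all three flatten the same axes (the input shape is arbitrary):

```python
import torch
from torch import nn

x = torch.arange(24).reshape(2, 3, 4)

a = x.flatten(start_dim=1)           # tensor method (OOP style)
b = torch.flatten(x, start_dim=1)    # functional form
c = nn.Flatten()(x)                  # module; start_dim=1 is its default

print(a.shape, b.shape, c.shape)     # torch.Size([2, 12]) for all three
print(torch.equal(a, b) and torch.equal(b, c))  # True
```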
All three are identical and share the same implementation, the only difference being that nn.Flatten has start_dim set to 1 by default to avoid flattening the first axis (usually the batch axis), while the other two flatten from axis=0 to axis=-1 (i.e., the entire tensor) if no arguments are given.
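Because only the defaults differ, passing start_dim (and end_dim) explicitly makes the forms interchangeable. A short sketch, with an arbitrary input shape:

```python
import torch
from torch import nn

x = torch.arange(24).reshape(2, 3, 4)

# nn.Flatten with start_dim=0 matches the function's default (flatten everything).
everything = nn.Flatten(start_dim=0)(x)
print(torch.equal(everything, torch.flatten(x)))  # True
print(everything.shape)                           # torch.Size([24])

# Passing start_dim=1 to the function reproduces nn.Flatten's default.
batch_kept = torch.flatten(x, start_dim=1)
print(torch.equal(batch_kept, nn.Flatten()(x)))   # True

# end_dim works the same way in all forms, e.g. merge only axes 0 and 1.
partial = x.flatten(0, 1)
print(torch.equal(partial, nn.Flatten(0, 1)(x)))  # True
print(partial.shape)                              # torch.Size([6, 4])
```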
Upvotes: 25