Reputation: 209
PyTorch's nn.Linear(in_features, out_features) accepts a tensor of size (N_batch, N_1, N_2, ..., N_end), where N_end = in_features. The output is a tensor of size (N_batch, N_1, N_2, ..., out_features).
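For concreteness, here is a minimal check of that shape contract (the concrete sizes are just made-up examples):

```python
import torch
import torch.nn as nn

# Made-up sizes, only to illustrate the documented shape behaviour.
layer = nn.Linear(in_features=7, out_features=2)
x = torch.randn(10, 4, 5, 6, 7)   # (N_batch, N_1, N_2, N_3, in_features)
y = layer(x)
print(y.shape)                    # torch.Size([10, 4, 5, 6, 2])
```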
It isn't very clear to me how it behaves in the following situations:

- If v is a single row vector, the output will be A^T v + b.
- If M is a matrix, it is treated as a batch of rows, and for every row v, A^T v + b is computed, and then everything is put back into matrix form.
- If the input is a tensor of shape (N_batch, 4, 5, 6, 7): is it true that the layer will output a batch of size N_batch of (1, 1, 1, N_out)-shaped vectors, everything shaped into a (N_batch, 4, 5, 6, N_out) tensor?

Upvotes: 0
Views: 3132
Reputation: 376
For a 1-dimensional input, the input is a single vector of dim in_features and the output is a vector of dim out_features, calculated exactly as you said.

For a 2-dimensional input, the input is N_batch vectors of dim in_features and the output is N_batch vectors of dim out_features, again calculated as you said.

For a 3-dimensional input of shape (N_batch, C, in_features), the input is N_batch matrices, each with C rows of dim in_features, and the output is N_batch matrices, each with C rows of dim out_features.
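A quick sketch of those three cases (the sizes here are arbitrary picks, not taken from your question):

```python
import torch
import torch.nn as nn

fc = nn.Linear(in_features=8, out_features=3)   # arbitrary example sizes

v = torch.randn(8)           # 1 dimension: a single vector
m = torch.randn(16, 8)       # 2 dimensions: N_batch = 16 vectors
t = torch.randn(16, 5, 8)    # 3 dimensions: N_batch = 16 matrices with C = 5 rows

print(fc(v).shape)   # torch.Size([3])
print(fc(m).shape)   # torch.Size([16, 3])
print(fc(t).shape)   # torch.Size([16, 5, 3])
```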
If the 3-dimensional case is hard to picture, one simple way is to flatten the input to shape (N_batch * C, in_features), so it becomes N_batch * C rows of dim in_features, which is exactly the 2-dimensional case. This flattening involves no computation; it just rearranges the input.
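As a rough check that flattening changes nothing about the result (again with made-up sizes):

```python
import torch
import torch.nn as nn

fc = nn.Linear(8, 3)
t = torch.randn(16, 5, 8)                            # (N_batch, C, in_features)

direct = fc(t)                                       # apply to the 3-d tensor directly
flat = fc(t.reshape(16 * 5, 8)).reshape(16, 5, 3)    # flatten, apply, restore the shape

print(torch.allclose(direct, flat))                  # True
```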
So in your third case, the output is a (N_batch, 4, 5, 6, N_out) tensor, or, after rearranging, N_batch * 4 * 5 * 6 vectors of dim N_out. Your proposed shape with all the 1 dims is not correct, since such a tensor would hold only N_batch * N_out elements in total.
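A small sketch of that third case, using example values for the sizes your question left symbolic:

```python
import torch
import torch.nn as nn

N_batch, N_out = 10, 2                 # example values, not from the question
fc = nn.Linear(7, N_out)

x = torch.randn(N_batch, 4, 5, 6, 7)
y = fc(x)

print(y.shape)          # torch.Size([10, 4, 5, 6, 2])
print(y.numel())        # 2400 output elements in total
print(N_batch * N_out)  # 20 -- a (N_batch, 1, 1, 1, N_out) tensor could not hold them all
```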
If you dig into PyTorch's internal C++ implementation, you can find that the native matmul implementation, which is the exact function used by nn.Linear, actually flattens the dimensions in the way I have described.
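You can also reproduce the layer's result with a plain matmul against the transposed weight, which broadcasts over all leading dimensions (a sketch of the computation, not the literal internal code path):

```python
import torch
import torch.nn as nn

fc = nn.Linear(7, 2)
x = torch.randn(10, 4, 5, 6, 7)

# nn.Linear computes x @ W^T + b; matmul broadcasts over the leading dims.
manual = torch.matmul(x, fc.weight.t()) + fc.bias
print(torch.allclose(fc(x), manual))   # True, up to floating-point rounding
```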
Upvotes: 2