Lilla

Reputation: 209

Working of nn.Linear with multiple dimensions

PyTorch's nn.Linear(in_features, out_features) accepts a tensor of size (N_batch, N_1, N_2, ..., N_end), where N_end = in_features. The output is a tensor of size (N_batch, N_1, N_2, ..., out_features).

It isn't very clear to me how it behaves in the following situations:

  1. If v is a single vector, the output is A^T v + b.
  2. If M is a matrix, it is treated as a batch of rows: for every row v, A^T v + b is computed, and the results are stacked back into a matrix.
  3. What if the input tensor has a higher rank? Say the input has shape (N_batch, 4, 5, 6, 7). Is it true that the layer outputs a batch of size N_batch of (1, 1, 1, N_out)-shaped vectors, everything shaped into a (N_batch, 4, 5, 6, N_out) tensor?

Upvotes: 0

Views: 3132

Answers (1)

ZNZNZ

Reputation: 376

For one dimension, the input is a single vector of size in_features and the output is a vector of size out_features, computed as you described.
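A quick sketch of the one-dimensional case (the sizes 7 and 3 are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

linear = nn.Linear(7, 3)    # in_features=7, out_features=3 (arbitrary)

v = torch.randn(7)          # a single vector, no batch dimension
y = linear(v)
print(y.shape)              # torch.Size([3])

# Same result as the explicit formula v @ A^T + b:
manual = v @ linear.weight.T + linear.bias
print(torch.allclose(y, manual, atol=1e-6))  # True
```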

For two dimensions, the input is N_batch vectors of size in_features, and the output is N_batch vectors of size out_features, again computed as you described.
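The batched two-dimensional case, with a made-up batch size of 10:

```python
import torch
import torch.nn as nn

linear = nn.Linear(7, 3)

x = torch.randn(10, 7)      # a batch of 10 vectors of size in_features=7
y = linear(x)
print(y.shape)              # torch.Size([10, 3])
```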

For three dimensions, the input has shape (N_batch, C, in_features): N_batch matrices, each with C rows of size in_features. The output is N_batch matrices, each with C rows of size out_features.

If the three-dimensional case is hard to picture, one simple trick is to flatten the input to shape (N_batch * C, in_features). The input then becomes N_batch * C rows of size in_features, which is exactly the two-dimensional case. This flattening involves no computation; it just rearranges the input.
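You can verify that the layer and the flatten-then-apply approach give the same result (shapes here are arbitrary examples):

```python
import torch
import torch.nn as nn

linear = nn.Linear(7, 3)

x = torch.randn(10, 4, 7)   # (N_batch=10, C=4, in_features=7)
y = linear(x)               # shape (10, 4, 3)

# Flatten to (N_batch * C, in_features), apply, then restore the shape:
y_flat = linear(x.reshape(-1, 7)).reshape(10, 4, 3)
print(torch.allclose(y, y_flat, atol=1e-6))  # True
```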

So in your case 3, the output has shape (N_batch, 4, 5, 6, N_out), or after rearranging, N_batch * 4 * 5 * 6 vectors of size N_out. Your shape with all the 1 dims is not correct, since that would leave only N_batch * N_out elements in total.
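Checking the shapes from case 3 directly (N_batch=2 and N_out=3 are arbitrary choices; in_features=7 matches the question's last dimension):

```python
import torch
import torch.nn as nn

linear = nn.Linear(7, 3)            # in_features=7, N_out=3

x = torch.randn(2, 4, 5, 6, 7)      # (N_batch, 4, 5, 6, in_features)
y = linear(x)
print(y.shape)                      # torch.Size([2, 4, 5, 6, 3])

# Equivalently, 2 * 4 * 5 * 6 vectors of size N_out:
print(y.reshape(-1, 3).shape)       # torch.Size([240, 3])
```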

If you dig into PyTorch's internal C++ implementation, you can see that the matmul implementation actually flattens the dimensions as described above; this native matmul is the exact function used by nn.Linear.

Upvotes: 2
