Lukas

Reputation: 543

How does an instance of pytorch's `nn.Linear()` process a tuple of tensors?

In the Annotated Transformer's implementation of multi-head attention, three tensors (query, key, value) are all passed to an `nn.Linear(d_model, d_model)`:

```python
# some class definition ...
self.linears = clones(nn.Linear(d_model, d_model), 4)  # deep-copied list of nn.Linear modules, collected in an nn.ModuleList
# more code ...
query, key, value = [
  lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
  for lin, x in zip(self.linears, (query, key, value))
]
```
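To see what the snippet actually computes, here is a minimal self-contained sketch. It assumes `clones` deep-copies a module N times into an `nn.ModuleList`, as the code comment describes, and uses hypothetical sizes (`nbatches=2`, `seq_len=5`, `d_model=16`, `h=4`); `self.h` and `self.d_k` become plain variables. Note that `query`, `key` and `value` are three separate tensors, each of shape `(nbatches, seq_len, d_model)`.

```python
import copy

import torch
import torch.nn as nn


def clones(module, n):
    # Assumed helper: n deep copies of `module`, each with its own parameter
    # tensors (equal in value at first), registered via nn.ModuleList.
    return nn.ModuleList([copy.deepcopy(module) for _ in range(n)])


# Hypothetical sizes chosen only for illustration.
nbatches, seq_len, d_model, h = 2, 5, 16, 4
d_k = d_model // h

linears = clones(nn.Linear(d_model, d_model), 4)

# Three separate inputs, each of shape (nbatches, seq_len, d_model).
query = torch.randn(nbatches, seq_len, d_model)
key = torch.randn(nbatches, seq_len, d_model)
value = torch.randn(nbatches, seq_len, d_model)

# Each layer projects one tensor; the result is split into h heads of size d_k
# and the head axis is moved in front of the sequence axis.
query, key, value = [
    lin(x).view(nbatches, -1, h, d_k).transpose(1, 2)
    for lin, x in zip(linears, (query, key, value))
]

print(query.shape)  # torch.Size([2, 4, 5, 4])
```

The output shape `(nbatches, h, seq_len, d_k)` shows that no concatenation of query, key and value takes place: each tensor keeps its own batch and sequence dimensions.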

My question: what happens at `lin(x)` when an instance of `nn.Linear()` is called on the tuple `(query, key, value)`? Is the tuple somehow concatenated into a single tensor? If so, how, i.e. along which dimension are the tensors concatenated?

Upvotes: 0

Views: 165

Answers (1)

dungxibo123

Reputation: 126

```python
self.linears = clones(nn.Linear(d_model, d_model), 4)  # deep-copied list of nn.Linear modules, collected in an nn.ModuleList
# more code ...
query, key, value = [
  lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
  for lin, x in zip(self.linears, (query, key, value))
]
```

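Actually, `nn.Linear` is never called on the tuple (Q, K, V), and nothing is concatenated. Each layer is applied to exactly one tensor and acts on that tensor's last dimension (`d_model`). Your code is equivalent to: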

```python
out_Q = self.linears[0](Q)
out_K = self.linears[1](K)
out_V = self.linears[2](V)
```

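`zip(A, B)` yields the pairs `(A[0], B[0])`, `(A[1], B[1])`, ... one at a time, so each layer receives one tensor, independently of the others. Because `zip` stops at the shorter iterable, only the first three of the four layers are used here; `self.linears[3]` is not touched by this loop.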
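The pairing behaviour of `zip` can be checked without PyTorch at all. In this small sketch, strings stand in for the four layers and the three input tensors (the names are placeholders, not real objects):

```python
# Placeholders for the four layers and the three inputs.
layers = ["linears[0]", "linears[1]", "linears[2]", "linears[3]"]
inputs = ("query", "key", "value")

# zip pairs items by position and stops when the shorter iterable runs out,
# so the fourth layer never gets paired with anything.
pairs = list(zip(layers, inputs))
for layer, x in pairs:
    print(f"{layer}({x})")
```

This prints `linears[0](query)`, `linears[1](key)` and `linears[2](value)`: no element ever receives the whole tuple.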

Or, written out with the original variable names:

```python
query = self.linears[0](query)
key = self.linears[1](key)
value = self.linears[2](value)
```
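A quick way to convince yourself of the equivalence is to compute both versions and compare them. This sketch covers only the projection step (the `.view(...).transpose(...)` reshaping is omitted) and uses hypothetical sizes; it also checks that `nn.Linear` works on the last dimension only, so projecting a single position gives the same vector as slicing the full output.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes for illustration only.
nbatches, seq_len, d_model = 2, 5, 16
linears = nn.ModuleList([nn.Linear(d_model, d_model) for _ in range(3)])

query = torch.randn(nbatches, seq_len, d_model)
key = torch.randn(nbatches, seq_len, d_model)
value = torch.randn(nbatches, seq_len, d_model)

# Version 1: list comprehension over zip (projection step only).
q1, k1, v1 = [lin(x) for lin, x in zip(linears, (query, key, value))]

# Version 2: one explicit call per tensor.
q2 = linears[0](query)
k2 = linears[1](key)
v2 = linears[2](value)

# nn.Linear maps the last dimension (d_model -> d_model) and leaves the
# leading (batch, sequence) dimensions alone.
row = linears[0](query[0, 0])  # a single position, shape (d_model,)

print(torch.allclose(q1, q2), torch.allclose(k1, k2), torch.allclose(v1, v2))
print(torch.allclose(row, q2[0, 0], atol=1e-6))
```

If the tuple were concatenated, the outputs would not line up one-to-one with the inputs like this; instead every output keeps the shape `(nbatches, seq_len, d_model)` of its own input.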

Upvotes: 1
