Reputation: 543
In the annotated transformer's implementation of multi-head attention, three tensors (query, key, value) are each passed through an nn.Linear(d_model, d_model):
# some class definition ...
self.linears = clones(nn.Linear(d_model, d_model), 4)  # four deep-copied nn.Linear modules collected in an nn.ModuleList
# more code ...
query, key, value = [
    lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    for lin, x in zip(self.linears, (query, key, value))
]
My question: what happens at lin(x) when an instance of nn.Linear() is called on the tuple (query, key, value)? Is the tuple somehow concatenated into a single tensor? If so, how, i.e. along which dimension are the tensors concatenated?
Upvotes: 0
Views: 165
Reputation: 126
self.linears = clones(nn.Linear(d_model, d_model), 4)  # four deep-copied nn.Linear modules collected in an nn.ModuleList
# more code ...
query, key, value = [
    lin(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)
    for lin, x in zip(self.linears, (query, key, value))
]
Actually, nn.Linear does not process the input as a tuple of Q, K, V.
In your code, the result is equivalent to this:
out_Q = self.linears[0](Q)
out_K = self.linears[1](K)
out_V = self.linears[2](V)
When you use zip(iterable_A, iterable_B), you get the pairs (A[0], B[0]), (A[1], B[1]), ... independently.
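A minimal, framework-free sketch of this pairing, using strings as stand-ins for the actual modules and tensors:

```python
# zip pairs each element of the first iterable with the element at the
# same position in the second, and stops at the shorter iterable.
linears = ["lin0", "lin1", "lin2", "lin3"]   # stand-ins for the 4 clones
tensors = ("query", "key", "value")

pairs = list(zip(linears, tensors))
print(pairs)  # [('lin0', 'query'), ('lin1', 'key'), ('lin2', 'value')]
```

Note that "lin3" is never paired here; in the annotated transformer the fourth clone is applied later, to the output of the attention step.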
Or, more specifically:
query = self.linears[0](query)
key = self.linears[1](key)
value = self.linears[2](value)
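To see that each projection transforms only its own tensor, here is a runnable toy version of the comprehension that swaps nn.Linear for plain tagging functions (the names are illustrative, not from the original code):

```python
# Toy stand-ins for the four cloned nn.Linear modules: each one tags its
# input so we can see which module processed which tensor.
linears = [lambda x, i=i: f"lin{i}({x})" for i in range(4)]
query, key, value = "Q", "K", "V"

# Same comprehension shape as in the question, minus view/transpose.
query, key, value = [lin(x) for lin, x in zip(linears, (query, key, value))]

print(query)  # lin0(Q)
print(key)    # lin1(K)
print(value)  # lin2(V)
```

Each tensor goes through exactly one linear layer; nothing is concatenated.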
Upvotes: 1