Is there a built-in function in PyTorch to efficiently compute all pairwise dot products between the rows of two tensors?
For example:
input: tensor A (shape N x D) and tensor B (shape N x D)
output: tensor C (shape N x N) such that C[i, j] = torch.dot(A[i], B[j])
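To pin down the intended semantics, here is a minimal sketch of the definition written as an explicit double loop. It is only a slow reference for checking faster versions; the function name `pairwise_dot_naive` and the sample values are illustrative, not part of the question.

```python
import torch

def pairwise_dot_naive(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference: C[i, j] = dot(A[i], B[j]), one pair at a time."""
    n, m = A.shape[0], B.shape[0]
    C = torch.empty(n, m, dtype=A.dtype, device=A.device)
    for i in range(n):
        for j in range(m):
            # Dot product of row i of A with row j of B
            C[i, j] = torch.dot(A[i], B[j])
    return C

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
print(pairwise_dot_naive(A, B))  # tensor([[17., 23.], [39., 53.]])
```

This costs N * N separate Python-level calls, which is why a single vectorized call is preferable.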
Isn't it simply
C = torch.mm(A, B.T) # same as C = A @ B.T
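This works because B.T has shape D x N, so multiplying an N x D matrix by a D x N matrix gives an N x N matrix whose entry (i, j) is the sum over d of A[i, d] * B[j, d], which is exactly torch.dot(A[i], B[j]). A quick check, assuming random float32 inputs with arbitrary sizes N=4, D=3:

```python
import torch

torch.manual_seed(0)
N, D = 4, 3
A = torch.randn(N, D)
B = torch.randn(N, D)

# B.T has shape (D, N), so (N x D) @ (D x N) -> (N x N)
C = torch.mm(A, B.T)

# Spot-check one entry against the definition C[i, j] = dot(A[i], B[j]);
# a small atol absorbs float32 summation-order differences
i, j = 1, 2
assert torch.allclose(C[i, j], torch.dot(A[i], B[j]), atol=1e-6)

# The @ operator computes the same product
assert torch.allclose(C, A @ B.T)
print(C.shape)  # torch.Size([4, 4])
```

Nothing requires both inputs to have the same number of rows: if B is M x D, the same call returns an N x M matrix.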
BTW, a very flexible tool for matrix/vector/tensor dot products is torch.einsum:
C = torch.einsum('id,jd->ij', A, B)
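In the subscript string 'id,jd->ij', the index d appears in both inputs but not in the output, so einsum multiplies along it and sums it out, while i and j become the output axes. The sketch below checks this against A @ B.T and, to illustrate the flexibility, adds a batch index b; the sizes and the batched variant are assumptions chosen for demonstration.

```python
import torch

torch.manual_seed(0)
N, D = 5, 3
A = torch.randn(N, D)
B = torch.randn(N, D)

# 'd' is shared by both inputs and absent from the output, so it is summed:
# C[i, j] = sum_d A[i, d] * B[j, d]
C_einsum = torch.einsum('id,jd->ij', A, B)
C_mm = A @ B.T
assert torch.allclose(C_einsum, C_mm, atol=1e-6)

# Adding a batch index 'b' gives pairwise dot products per batch element:
# inputs of shape (batch, N, D) -> output of shape (batch, N, N)
Ab = torch.randn(2, N, D)
Bb = torch.randn(2, N, D)
Cb = torch.einsum('bid,bjd->bij', Ab, Bb)
assert torch.allclose(Cb, torch.bmm(Ab, Bb.transpose(1, 2)), atol=1e-6)
```

For plain 2-D inputs, torch.mm and this einsum call compute the same result; einsum mainly pays off in readability once more indices are involved.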