obar

Reputation: 357

All pairwise dot products in PyTorch

Is there a built-in function to efficiently calculate all pairwise dot products of two tensors in PyTorch? E.g.:
input - tensor A (shape NxD)
        tensor B (shape NxD)

output - tensor C (shape NxN) such that C_i,j = torch.dot(A_i, B_j)?
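To pin down the definition, here is a minimal sketch of a plain double-loop reference implementation. The helper name pairwise_dot_loop is hypothetical, introduced only for illustration. It is slow (N*N separate torch.dot calls) but makes the indexing explicit.

```python
import torch

def pairwise_dot_loop(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: C[i, j] = torch.dot(A[i], B[j])."""
    n, m = A.shape[0], B.shape[0]
    C = torch.empty(n, m, dtype=A.dtype, device=A.device)
    for i in range(n):
        for j in range(m):
            # Row i of A dotted with row j of B
            C[i, j] = torch.dot(A[i], B[j])
    return C

A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
B = torch.tensor([[5.0, 6.0], [7.0, 8.0]])
print(pairwise_dot_loop(A, B))
# tensor([[17., 23.],
#         [39., 53.]])
```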

Upvotes: 0

Views: 3493

Answers (1)

Shai

Reputation: 114786

Isn't it simply

C = torch.mm(A, B.T)  # same as C = A @ B.T
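A quick sketch of why this works: transposing B turns its rows into columns, so the matrix product (NxD) @ (DxN) puts the dot product of row i of A and row j of B at position (i, j). The random shapes below are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)
N, D = 4, 3
A = torch.randn(N, D)
B = torch.randn(N, D)

# (N x D) @ (D x N) -> (N x N), with C[i, j] = A[i] . B[j]
C = A @ B.T  # equivalent to torch.mm(A, B.T)

# Spot-check one entry against torch.dot
assert torch.allclose(C[1, 2], torch.dot(A[1], B[2]), atol=1e-5)
print(C.shape)  # torch.Size([4, 4])
```

The same expression also works when B has a different number of rows M, giving an NxM result.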

BTW, a very flexible tool for matrix/vector/tensor dot products is torch.einsum:

C = torch.einsum('id,jd->ij', A, B)
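A small sketch checking that the einsum form gives the same result as the matrix product. In the subscript string 'id,jd->ij', the index d appears in both inputs but not in the output, so einsum multiplies along d and sums it away, leaving an (i, j) grid. The tensor sizes and the tolerance are arbitrary choices for this example.

```python
import torch

torch.manual_seed(0)
A = torch.randn(5, 8)
B = torch.randn(5, 8)

# d is shared and absent from the output -> summed over
C_einsum = torch.einsum('id,jd->ij', A, B)
C_matmul = A @ B.T

# The two may differ by float rounding, hence allclose rather than equal
print(torch.allclose(C_einsum, C_matmul, atol=1e-5))  # True
```

The advantage of einsum is that the subscripts document the contraction directly, and the same notation extends to batched cases without reshaping or transposing by hand.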

Upvotes: 2
