Inkyu Kim

Reputation: 155

PyTorch - matrix multiplication of slices from two tensors

Suppose I have two tensors with the following sizes:

A = [N x L x T]

B = [N x T x K]

I would like to multiply matching slices of the two tensors, like this:

matmul_slice = A[0,:,:] @ B[0,:,:] = [L x T] @ [T x K] = [L x K]

I would like to repeat this for all N slices along dimension 0, so that I end up with a final tensor of size [N x L x K].
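For reference, the per-slice computation described above can be written as an explicit loop (the slow version I want to avoid). This is a minimal sketch; the sizes N=4, L=3, T=5, K=2 are arbitrary values chosen for illustration:

```python
import torch

torch.manual_seed(0)
N, L, T, K = 4, 3, 5, 2  # arbitrary example sizes
A = torch.randn(N, L, T)
B = torch.randn(N, T, K)

# Slow reference: multiply matching slices one at a time, then stack along dim 0
out_loop = torch.stack([A[n] @ B[n] for n in range(N)])
print(out_loop.shape)  # torch.Size([4, 3, 2])
```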

I do not want to loop over N because that slows down the computation. I have been playing around with torch.matmul and torch.einsum, but I cannot get the correct answer.

How can I achieve this in a compact way?

Upvotes: 0

Views: 135

Answers (2)

Ivan

Reputation: 40738

The matmul operator works like this: (*, i, j) @ (*, j, k) = (*, i, k).
So in your case there is no need to transpose A; simply use A @ B.
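A minimal sketch of this, with arbitrary example sizes: `A @ B` (equivalently `torch.matmul(A, B)`) treats the leading dimension as a batch dimension, so slice `n` of the result is `A[n] @ B[n]`.

```python
import torch

torch.manual_seed(0)
N, L, T, K = 4, 3, 5, 2  # arbitrary example sizes
A = torch.randn(N, L, T)
B = torch.randn(N, T, K)

# Batched matrix product over dim 0: (N, L, T) @ (N, T, K) -> (N, L, K)
out = A @ B
print(out.shape)  # torch.Size([4, 3, 2])

# Each batch slice matches the product of the corresponding input slices
# (small atol to absorb float32 rounding differences)
print(torch.allclose(out[0], A[0] @ B[0], atol=1e-6))  # True
```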

If you prefer, you can use torch.einsum to show the explicit expression:

torch.einsum('bij,bjk->bik', A, B)
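A short sketch of the einsum version, again with arbitrary example sizes. Here `b` indexes the batch (N), and `j` (T) is summed over because it does not appear in the output subscripts:

```python
import torch

torch.manual_seed(0)
A = torch.randn(4, 3, 5)  # [N, L, T]
B = torch.randn(4, 5, 2)  # [N, T, K]

# 'b' is kept (batch), 'j' is contracted, 'i' and 'k' form each output matrix
out = torch.einsum('bij,bjk->bik', A, B)
print(out.shape)  # torch.Size([4, 3, 2])

# Same result as the batched matmul (small atol for float32 rounding)
print(torch.allclose(out, A @ B, atol=1e-6))  # True
```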

Note: torch.bmm works the same way as matmul but does not broadcast: A.bmm(B).
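A sketch of the difference, with arbitrary example sizes and a hypothetical shared matrix `W` used only to illustrate broadcasting: `bmm` needs two 3-D tensors with the same batch size, while `matmul` also broadcasts a single 2-D matrix across the batch.

```python
import torch

A = torch.randn(4, 3, 5)  # [N, L, T]
B = torch.randn(4, 5, 2)  # [N, T, K]

# bmm: both inputs must be 3-D with matching batch size
out_bmm = A.bmm(B)  # shape [4, 3, 2]

# matmul broadcasts: one [T, K] matrix is applied to every slice of A
W = torch.randn(5, 2)  # hypothetical shared matrix, for illustration only
out_bc = A @ W  # shape [4, 3, 2]

# bmm does not broadcast, so passing a 2-D tensor raises an error
bmm_failed = False
try:
    A.bmm(W)
except RuntimeError:
    bmm_failed = True
print(bmm_failed)  # True
```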

Upvotes: 1

MinhNH

Reputation: 564

torch.bmm is what you need, although torch.matmul should be equivalent in your case. I think you should recheck your computation.
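One way to recheck, sketched under the assumption that the inputs are plain 3-D tensors as in the question (sizes are arbitrary): compare `torch.bmm` and `torch.matmul` against an explicit loop. Using `float64` keeps rounding differences negligible, so all three should agree.

```python
import torch

torch.manual_seed(0)
N, L, T, K = 4, 3, 5, 2  # arbitrary example sizes
A = torch.randn(N, L, T, dtype=torch.float64)
B = torch.randn(N, T, K, dtype=torch.float64)

out_bmm = torch.bmm(A, B)
out_matmul = torch.matmul(A, B)
out_loop = torch.stack([A[n] @ B[n] for n in range(N)])

# All three should produce the same [N, L, K] tensor
print(torch.allclose(out_bmm, out_matmul))  # True
print(torch.allclose(out_bmm, out_loop))    # True
```

If these agree but your own result does not, the mismatch is likely in how A and B were built (for example a transposed or permuted input) rather than in the multiplication itself.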

Upvotes: 1
