Ecanyte

Element-wise matrix vector multiplication

I have a tensor m of shape n x 3 x 3 that stores n 3 x 3 matrices, and a tensor v of shape n x 3 that stores n 3-element vectors. How can I apply element-wise matrix-vector multiplication, i.e. multiply the i-th matrix by the i-th vector, to get an output tensor of shape n x 3?

Thanks for your help.
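To pin down the intended result, here is a minimal sketch (assuming PyTorch is installed) that multiplies each matrix-vector pair with a plain Python loop. `batched_matvec_loop` is a hypothetical helper name, not part of PyTorch. It is slow for large n and is only meant as a reference to check vectorized versions against.

```python
import torch

def batched_matvec_loop(m: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: multiply the i-th matrix by the i-th vector, one pair at a time."""
    # m: (n, 3, 3), v: (n, 3). Each m[i] @ v[i] is (3, 3) @ (3,) -> (3,)
    # Stacking the n results gives an (n, 3) tensor.
    return torch.stack([m[i] @ v[i] for i in range(m.shape[0])])

n = 4
m = torch.randn(n, 3, 3)
v = torch.randn(n, 3)
out = batched_matvec_loop(m, v)
print(out.shape)  # torch.Size([4, 3])
```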

Answers (1)

Ivan

You want to perform a matrix multiplication (__matmul__) in a batch-wise manner. The natural choice is the batch matrix-multiplication operator torch.bmm. Keep in mind that torch.bmm only accepts 3D tensors, so you first need to unsqueeze a dimension on v. Indexing the last dimension with None, as in v[..., None], gives a tensor of shape (n, 3, 1). The product then also has shape (n, 3, 1), so call .squeeze(-1) on it to get the (n, 3) output you asked for.
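A quick sketch (assuming PyTorch) showing that v[..., None] is just one way of adding a trailing dimension: v.unsqueeze(-1) and v[:, :, None] produce the same (n, 3, 1) tensor, and .squeeze(-1) removes that dimension again.

```python
import torch

n = 4
v = torch.randn(n, 3)

# Three equivalent ways to add a trailing singleton dimension: (n, 3) -> (n, 3, 1)
a = v[..., None]
b = v.unsqueeze(-1)
c = v[:, :, None]

# squeeze(-1) drops the trailing size-1 dimension: (n, 3, 1) -> (n, 3)
back = a.squeeze(-1)

print(a.shape)     # torch.Size([4, 3, 1])
print(back.shape)  # torch.Size([4, 3])
```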

With torch.bmm:

>>> torch.bmm(m, v[..., None]).squeeze(-1)
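A self-contained version of the torch.bmm approach, as a sketch assuming PyTorch. It uses small integer values stored as floats so that every product and sum is exact, which lets the check against the per-pair loop use exact equality.

```python
import torch

n = 4
# Small integer values stored as floats, so every product and sum is exact.
m = torch.randint(-5, 6, (n, 3, 3)).float()   # n matrices of shape 3 x 3
v = torch.randint(-5, 6, (n, 3)).float()      # n vectors of length 3

# torch.bmm takes two 3D tensors (b, i, j) and (b, j, k) with the same batch size b.
out = torch.bmm(m, v[..., None])   # (n, 3, 3) @ (n, 3, 1) -> (n, 3, 1)
out = out.squeeze(-1)              # drop the trailing singleton dimension -> (n, 3)

# Check against multiplying each matrix-vector pair separately.
expected = torch.stack([m[i] @ v[i] for i in range(n)])
print(out.shape)                   # torch.Size([4, 3])
print(torch.equal(out, expected))  # True
```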

torch.matmul also handles this case out of the box, since it treats the leading dimension of 3D inputs as a batch dimension:

>>> torch.matmul(m, v[..., None]).squeeze(-1) # same as (m @ v[..., None]).squeeze(-1)
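A sketch (assuming PyTorch) comparing torch.matmul, the @ operator, and torch.bmm on the same inputs. Random floats are used here, so the comparison uses torch.allclose rather than exact equality.

```python
import torch

torch.manual_seed(0)
n = 4
m = torch.randn(n, 3, 3)
v = torch.randn(n, 3)

# With two 3D inputs, torch.matmul treats dim 0 as a batch dimension,
# so this is the same batched product that torch.bmm computes.
via_matmul = torch.matmul(m, v[..., None])   # (n, 3, 1)
via_operator = m @ v[..., None]              # the @ operator calls torch.matmul
via_bmm = torch.bmm(m, v[..., None])         # (n, 3, 1)

print(via_matmul.shape)                      # torch.Size([4, 3, 1])
print(torch.allclose(via_matmul, via_bmm))
print(torch.allclose(via_matmul, via_operator))
```

One practical difference: torch.bmm requires both inputs to be 3D with the same batch size, whereas torch.matmul also broadcasts batch dimensions, so a single (3, 3) matrix multiplied with v[..., None] is applied to all n vectors.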

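If you want explicit control over the operation, use torch.einsum. Because the summed index j does not appear in the output, it returns shape (n, 3) directly, with no unsqueeze or squeeze needed: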

>>> torch.einsum('bij,bj->bi', m, v)
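A sketch (assuming PyTorch) that spells out what the subscript string computes and checks it against the squeezed torch.matmul result. Random floats are used, so the check is approximate.

```python
import torch

torch.manual_seed(0)
n = 4
m = torch.randn(n, 3, 3)
v = torch.randn(n, 3)

# 'bij,bj->bi': for each batch index b, out[b, i] = sum over j of m[b, i, j] * v[b, j].
# The repeated index j is summed away, so no unsqueeze/squeeze is needed.
out = torch.einsum('bij,bj->bi', m, v)

reference = torch.matmul(m, v[..., None]).squeeze(-1)
print(out.shape)                                  # torch.Size([4, 3])
print(torch.allclose(out, reference, atol=1e-5))
```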
