BenedictWilkins

Multiply many matrices by many vectors in PyTorch

I am trying to multiply the following:

A batch of matrices N x M x D
A batch of vectors N x D x 1
To get a result: N x M x 1

as if I were doing N separate matrix-vector products, each one an M x D matrix times a D x 1 vector.
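To pin down the intended result, here is a slow reference sketch that performs the N products one at a time in a Python loop. The sizes and the names `mats`, `vecs` and `result` are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
N, M, D = 4, 3, 5
mats = torch.randn(N, M, D)  # batch of N matrices, each M x D
vecs = torch.randn(N, D, 1)  # batch of N column vectors, each D x 1

# Reference version: one ordinary (M x D) @ (D x 1) product per batch element,
# stacked back together along a new leading batch dimension.
result = torch.stack([mats[n] @ vecs[n] for n in range(N)])
print(result.shape)  # torch.Size([4, 3, 1])
```

This is only useful for checking a vectorized version; looping in Python is much slower for large N.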

I can't seem to find the right function in PyTorch.

As far as I can tell, torch.bmm only works for a batch of vectors and a single matrix. If I have to use torch.einsum, then so be it, but I'd rather not!

Answers (1)

Quang Hoang

It's pretty straightforward and intuitive with einsum:

torch.einsum('ijk, ikl->ijl', mats, vecs)
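A small sketch of what the subscripts mean, assuming the shapes from the question and hypothetical sizes: `i` is the batch index, `j` runs over the M rows, `k` over the shared inner dimension D, and `l` over the single column. Because `k` is missing from the output, einsum sums over it. The manual check rebuilds one output entry as an explicit sum.

```python
import torch

torch.manual_seed(0)
N, M, D = 4, 3, 5  # hypothetical sizes
mats = torch.randn(N, M, D)
vecs = torch.randn(N, D, 1)

# i: batch (N), j: matrix row (M), k: inner dimension (D), l: vector column (1).
# k appears in both inputs but not in the output, so it is summed over.
out = torch.einsum('ijk,ikl->ijl', mats, vecs)
print(out.shape)  # torch.Size([4, 3, 1])

# Rebuild one entry by hand: out[n, m, 0] = sum over k of mats[n, m, k] * vecs[n, k, 0]
n, m = 2, 1
manual = sum(mats[n, m, k] * vecs[n, k, 0] for k in range(D))
print(torch.allclose(out[n, m, 0], manual, atol=1e-6))
```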

But your operation is just an ordinary batched matrix multiplication, so the @ operator (torch.matmul) does it directly:

mats @ vecs
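A sketch comparing the options, with hypothetical sizes. torch.bmm also fits this case: it multiplies a batch of matrices of shape (b, n, m) by a batch of shape (b, m, p), so a (N, D, 1) batch of column vectors is a valid second argument. The last lines show one way to handle vectors stored as N x D without the trailing 1, which is an assumption about the data layout rather than something from the question.

```python
import torch

torch.manual_seed(0)
N, M, D = 4, 3, 5  # hypothetical sizes
mats = torch.randn(N, M, D)
vecs = torch.randn(N, D, 1)

a = mats @ vecs                               # operator form of torch.matmul
b = torch.matmul(mats, vecs)                  # same call, written out
c = torch.bmm(mats, vecs)                     # needs two 3-D tensors with the same batch size
d = torch.einsum('ijk,ikl->ijl', mats, vecs)  # einsum from the answer, without spaces

print(a.shape)  # torch.Size([4, 3, 1])
print(all(torch.allclose(a, t, atol=1e-6) for t in (b, c, d)))

# Drop the trailing singleton dimension if an N x M result is preferred.
print(a.squeeze(-1).shape)  # torch.Size([4, 3])

# If the vectors were stored as N x D, add the column dimension before multiplying.
vecs_flat = vecs.squeeze(-1)                        # shape (N, D)
e = (mats @ vecs_flat.unsqueeze(-1)).squeeze(-1)    # shape (N, M)
```

All four forms compute the same thing here; `@` is usually the most readable, while `torch.bmm` is stricter about shapes and will not broadcast.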
