Zaya

Reputation: 336

Efficiently find the dot product of two lists of vectors stored as PyTorch tensors & preserve backprop

Suppose I had tensors X and Y which are both (batch_size, d) dimensional. I would like to find the (batch_size x 1) tensor resulting from [X[0]@Y[0].T, X[1]@Y[1].T, ...]

There are two ways I can think of doing this, neither of which are particularly efficient.

Way 1

product = torch.eye(batch_size) * (X @ Y.T)
product = torch.sum(product, dim=1)

This works, but for large batches it computes the entire (batch_size, batch_size) matrix X@Y.T only to keep its diagonal, so most of the computation is wasted.

Way 2

product = torch.stack(
    [ X[i] @ Y[i] for i in range(X.size(0)) ],
    dim=0
)

This is good in that no cycles are wasted, but it won't leverage any of the built-in parallelism torch offers.

I'm aware that numpy has a method that will do this, but converting the tensors to np arrays will destroy the chain of backpropagation, and this is for a neural net, so that's not an option.

Am I missing an obvious built-in torch method, or am I stuck with these two options?

Upvotes: 1

Views: 2089

Answers (2)

Sebastien

Reputation: 1476

You could also do it with einsum:

product = torch.einsum("in,in->i", X, Y)
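
Note that the einsum result has shape (batch_size,); calling .unsqueeze(1) on it gives the (batch_size, 1) shape asked for. Below is a minimal sketch (shapes chosen just for illustration) checking it against an explicit loop and confirming gradients still flow:

import torch

batch_size, d = 4, 3
X = torch.randn(batch_size, d, requires_grad=True)
Y = torch.randn(batch_size, d)

# Row-wise dot products in one call
product = torch.einsum("in,in->i", X, Y)  # shape: (batch_size,)

# Sanity check against an explicit per-row loop
expected = torch.stack([X[i] @ Y[i] for i in range(batch_size)])
assert torch.allclose(product, expected)

# Backprop is preserved
product.sum().backward()
print(X.grad.shape)  # torch.Size([4, 3])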

Upvotes: 1

swag2198

Reputation: 2696

One way would be to use broadcasted (batched) matrix multiplication: reshape the rows of X into row vectors and the rows of Y into column vectors, then multiply.

import torch
X = X.reshape(batch_size, 1, d)           # row vectors:    (batch_size, 1, d)
Y = Y.reshape(batch_size, d, 1)           # column vectors: (batch_size, d, 1)
product = torch.matmul(X, Y).squeeze(1)   # (batch_size, 1, 1) -> (batch_size, 1)

The output product will have the required shape of (batch_size, 1) with the desired result.
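
A minimal end-to-end sketch (shapes here are just illustrative) showing the result and that backpropagation still reaches the original tensors:

import torch

batch_size, d = 4, 3
X = torch.randn(batch_size, d, requires_grad=True)
Y = torch.randn(batch_size, d, requires_grad=True)

# Batched matmul of (1, d) row vectors with (d, 1) column vectors
product = torch.matmul(X.reshape(batch_size, 1, d),
                       Y.reshape(batch_size, d, 1)).squeeze(1)
print(product.shape)  # torch.Size([4, 1])

# Gradients flow back to X and Y
product.sum().backward()
print(X.grad.shape, Y.grad.shape)  # torch.Size([4, 3]) torch.Size([4, 3])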

Upvotes: 3
