Neurobro

Reputation: 262

Speed up einsum for sparse tensors

I would like to perform the following batch of matrix multiplications:

proj = torch.einsum('abi,aic->abc', A, B)

where A is an n×n×d tensor and B is an n×d×d tensor.

When n gets large (~50k), this operation becomes very slow.

However, A is actually sparse in its first two dimensions, i.e., it could be stored as a set of index pairs (i, j) together with a corresponding set of 1×d vectors, one per pair.

Could someone help me speed this computation up?
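
For illustration, here is a minimal sketch of what computing only the stored entries might look like. It assumes the sparse form described above is held as two tensors: `idx` of shape (nnz, 2) with the index pairs and `vals` of shape (nnz, d) with the matching vectors. These names and the helper `sparse_batched_project` are hypothetical and not part of any library. Since each output row is just the stored vector times the matrix of B selected by the pair's first index, the result has the same sparsity pattern as A and can be returned as an (nnz, d) tensor.

```python
import torch

def sparse_batched_project(idx, vals, B):
    """Compute proj[a, b, :] = A[a, b, :] @ B[a] for the stored pairs only.

    idx  : (nnz, 2) long tensor of (a, b) index pairs (assumed layout)
    vals : (nnz, d) tensor, the vector stored for each pair
    B    : (n, d, d) tensor
    Returns an (nnz, d) tensor, row k holding proj[idx[k, 0], idx[k, 1], :].
    """
    a = idx[:, 0]
    # Gather the d x d matrix that belongs to each stored pair: (nnz, d, d).
    B_a = B[a]
    # (nnz, 1, d) @ (nnz, d, d) -> (nnz, 1, d), then drop the singleton dim.
    return torch.bmm(vals.unsqueeze(1), B_a).squeeze(1)

# Small check against the dense einsum from the question.
torch.manual_seed(0)
n, d = 6, 4
idx = torch.unique(torch.randint(0, n, (10, 2)), dim=0)  # distinct pairs
vals = torch.randn(idx.shape[0], d)
B = torch.randn(n, d, d)

A = torch.zeros(n, n, d)
A[idx[:, 0], idx[:, 1]] = vals
dense = torch.einsum('abi,aic->abc', A, B)
sparse_rows = sparse_batched_project(idx, vals, B)
```

Note that `B[a]` materializes an (nnz, d, d) tensor, so for a large number of stored pairs it may be necessary to process `idx` in chunks. Whether this is faster in practice depends on nnz and d.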

Upvotes: 1

Views: 135

Answers (0)
