I would like to perform the following batch of matrix multiplications:
proj = torch.einsum('abi,aic->abc', A, B)
where A is an n x n x d tensor and B is an n x d x d tensor.
When n gets large (around 50k), this operation becomes very slow.
However, A is sparse in the first two dimensions, i.e., it could be written as a set of indices (i,j) and a corresponding set of 1 x d vectors, one per index pair.
Could someone help me speed this computation up?
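To make the sparse layout concrete, here is a minimal sketch of one possible direction, not a tested solution for n around 50k. It assumes the nonzero entries of A are stored as hypothetical tensors `rows` and `cols` (index pairs) and `vals` (the matching 1 x d vectors); the helper name `sparse_batched_proj` is also made up for illustration. Since proj[a, b, :] = A[a, b, :] @ B[a], only the stored entries need a product:

```python
import torch

def sparse_batched_proj(rows, vals, B):
    """Compute the nonzero rows of einsum('abi,aic->abc', A, B).

    rows: (nnz,) long tensor, first index a of each stored entry (assumed layout)
    vals: (nnz, d) tensor, the stored vector A[a, b, :] for each entry
    B:    (n, d, d) tensor
    Returns an (nnz, d) tensor whose k-th row is vals[k] @ B[rows[k]].
    """
    # B[rows] gathers one (d, d) matrix per stored entry, giving shape (nnz, d, d).
    return torch.einsum('ki,kic->kc', vals, B[rows])

# Small comparison against the dense version on toy sizes.
torch.manual_seed(0)
n, d, nnz = 6, 4, 5
flat = torch.randperm(n * n)[:nnz]               # distinct (a, b) positions
rows = torch.div(flat, n, rounding_mode='floor')
cols = flat % n
vals = torch.randn(nnz, d)
B = torch.randn(n, d, d)

A = torch.zeros(n, n, d)
A[rows, cols] = vals                             # dense A, built only for the check
dense = torch.einsum('abi,aic->abc', A, B)
sparse_out = sparse_batched_proj(rows, vals, B)  # row k belongs at (rows[k], cols[k])
```

The result stays in the same sparse form: row k of `sparse_out` is the value of proj at position (rows[k], cols[k]), and every other position is zero. Note that `B[rows]` materializes an nnz x d x d tensor, so for many stored entries it may be necessary to process them in chunks.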
Upvotes: 1
Views: 135