clueless

Reputation: 251

Clarification on einsum equation

I came across some code on Huggingface (in a self-attention module) that uses torch.einsum, which I'm not too familiar with and would like some help interpreting. I've looked through this list of basic operations and their implementations in NumPy/PyTorch. The inputs are a 4D tensor and a 3D tensor.

This is the (explicit) einsum string:

'bhld,lrd->bhlr'

(Another einsum string used is similar:

'bhrd,lrd->bhlr')

What does this mean, and how else could this be implemented without using einsum? For example, I assume the second tensor must be transposed so that d is the first dimension.
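For reference, here is a minimal sketch of the kind of call in question, with placeholder sizes and variable names (q and pos are my guesses at the attention context, not taken from the Hugging Face code):

import torch

b, h, l, r, d = 2, 4, 6, 10, 16             # placeholder sizes
q = torch.randn(b, h, l, d)                 # the 4D tensor
pos = torch.randn(l, r, d)                  # the 3D tensor

scores = torch.einsum('bhld,lrd->bhlr', q, pos)
print(scores.shape)                         # torch.Size([2, 4, 6, 10])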

Upvotes: 0

Views: 255

Answers (1)

hpaulj

Reputation: 231385

'bhld,lrd->bhlr'

First arg is 4d, 2nd is 3d, result is 4d

'bh' passes through unchanged, and so does 'r'. The 'l' and 'd' dimensions of the two operands are matched up with multiplication, then the products are summed over 'd' ('l' is kept in the result).

In terms of a broadcasted sum of products, I think the equivalent is (not tested):

(A[:,:,:,None,:] * B[None, None, :,:,:]).sum(axis=-1)
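This is easy to check against einsum itself on small random arrays (a quick sketch, since the code above is untested; the sizes are arbitrary):

import numpy as np

b, h, l, r, d = 2, 3, 4, 5, 6
A = np.random.rand(b, h, l, d)              # 4d operand
B = np.random.rand(l, r, d)                 # 3d operand

ref = np.einsum('bhld,lrd->bhlr', A, B)

# (b,h,l,1,d) * (1,1,l,r,d) -> broadcast to (b,h,l,r,d), then sum over 'd'
out = (A[:, :, :, None, :] * B[None, None, :, :, :]).sum(axis=-1)

print(out.shape)                            # (2, 3, 4, 5)
print(np.allclose(ref, out))                # True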

With matmul, put 'd' of B 2nd to the last, and ensure that the first 3 dimensions broadcast. The None axis inserted into A survives as a singleton dimension in the product, so it has to be squeezed out:

(A[:,:,:,None,:] @ B.transpose(0,2,1)[None, None, :,:,:]).squeeze(-2)
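Again a quick sketch to verify (note that transpose(0,2,1) here is NumPy; the PyTorch equivalent would be B.permute(0, 2, 1)):

import numpy as np

b, h, l, r, d = 2, 3, 4, 5, 6
A = np.random.rand(b, h, l, d)
B = np.random.rand(l, r, d)

# (b,h,l,1,d) @ (1,1,l,d,r) -> (b,h,l,1,r); squeeze the leftover singleton
out = (A[:, :, :, None, :] @ B.transpose(0, 2, 1)[None, None, :, :, :]).squeeze(-2)

print(np.allclose(out, np.einsum('bhld,lrd->bhlr', A, B)))   # True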

The second string works the same way:

'bhrd,lrd->bhlr'

Same sum of products on 'd', but here 'r' is shared among all 3 (both operands and the result), and 'l' comes only from the second operand.
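The corresponding broadcasted form inserts the 'l' axis into A instead (same sketch style as above, arbitrary sizes):

import numpy as np

b, h, l, r, d = 2, 3, 4, 5, 6
A = np.random.rand(b, h, r, d)              # 'r' now sits in the first operand
B = np.random.rand(l, r, d)

# (b,h,1,r,d) * (1,1,l,r,d) -> broadcast to (b,h,l,r,d), then sum over 'd'
out = (A[:, :, None, :, :] * B[None, None, :, :, :]).sum(axis=-1)

print(np.allclose(out, np.einsum('bhrd,lrd->bhlr', A, B)))   # True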

Upvotes: 1
