helloworld

Reputation: 15

Multiplying a 4D tensor with a 3D tensor using numpy einsum or tensordot

I have a 3D tensor of shape (2, 5, 3) and a 4D tensor of shape (2, 5, 4, 3), and I am trying to compute a row-wise product between them, as described below.

As an example, consider the following 3D and 4D tensors:

A = [[[ 0  1  2]
  [ 3  4  5]
  [ 6  7  8]
  [ 9 10 11]
  [12 13 14]]

 [[15 16 17]
  [18 19 20]
  [21 22 23]
  [24 25 26]
  [27 28 29]]] 

B = [[[[77 11 61]
   [55 98 50]
   [58 29 13]
   [56 48 52]]

  [[57  1 18]
   [ 7 52  3]
   [40 95 85]
   [18 13 27]]

  [[17 28 49]
   [48  2 62]
   [57  4  7]
   [86 62 98]]

  [[61 72 99]
   [36 49 71]
   [58 82 80]
   [54 45 90]]

  [[87 53 27]
   [43 90 25]
   [21 80 66]
   [ 2 52 98]]]


 [[[75 24 33]
   [87 14 82]
   [91 46 90]
   [79  6 13]]

  [[86 83 75]
   [93 33 36]
   [62  2 92]
   [91 45 12]]

  [[ 1  9 32]
   [41 77 76]
   [21 60 22]
   [44 59 79]]

  [[ 5 90 88]
   [31 54 93]
   [66 20 43]
   [69 15 79]]

  [[50 58 84]
   [78 35 92]
   [60 83 93]
   [44 31 46]]]]

The product tensor C has the same dimensions as the 4D tensor B and is obtained by multiplying each row of the 3D tensor A element-wise with every row of the corresponding 4 x 3 sub-matrix in B. So the first 4 x 3 sub-matrix in C is:

[0 1 2] * [[77 11 61]
           [55 98 50]
           [58 29 13]
           [56 48 52]]

= [[  0  11 122]
   [  0  98 100]
   [  0  29  26]
   [  0  48 104]]

The same is done for the other 9 rows of A to yield a (2, 5, 4, 3) tensor.

Is there a way to achieve this using either tensordot or einsum in numpy? I have looked through various posts and tried some trial and error, but with no luck. I would greatly appreciate a solution or even a useful pointer!

Upvotes: 0

Views: 752

Answers (1)

hpaulj

Reputation: 231385

Add a dimension to A so it is (2,5,1,3):

A[:,:,None,:]*B 
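As a minimal sketch of why this works: inserting the length-1 axis lets NumPy broadcast each row of A across the 4 rows of the matching sub-matrix in B. Here A is rebuilt with `np.arange` to match the question, while B is a random stand-in (an assumption, not the question's B); only `first_block` copies the first sub-matrix from the question.

```python
import numpy as np

# A matches the question: values 0..29 arranged as (2, 5, 3)
A = np.arange(30).reshape(2, 5, 3)

# Stand-in for B (assumption: random integers, NOT the question's data)
rng = np.random.default_rng(0)
B = rng.integers(0, 100, size=(2, 5, 4, 3))

# (2, 5, 1, 3) * (2, 5, 4, 3) broadcasts over the length-1 axis
C = A[:, :, None, :] * B
print(C.shape)  # (2, 5, 4, 3)

# First sub-matrix of B, copied from the question, times the first row of A
first_block = np.array([[77, 11, 61],
                        [55, 98, 50],
                        [58, 29, 13],
                        [56, 48, 52]])
print(A[0, 0] * first_block)
```

The last print should reproduce the first sub-matrix of C shown in the question.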

With einsum, this should work, but I consider it to be overkill (there's no sum of products, since every index appears in the output):

np.einsum('ijl,ijkl->ijkl',A,B)  

(I can't prove this with your arrays since B is too big to replicate with a copy-n-paste.)
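A quick way to check that the two expressions agree, sketched here with a random stand-in for B (an assumption; only the shape matches the question):

```python
import numpy as np

# A matches the question; B is a random stand-in with the same shape
A = np.arange(30).reshape(2, 5, 3)
rng = np.random.default_rng(0)
B = rng.integers(0, 100, size=(2, 5, 4, 3))

# Broadcasting version: add a length-1 axis to A
C_broadcast = A[:, :, None, :] * B

# einsum version: no index is dropped from the output, so nothing is summed
C_einsum = np.einsum('ijl,ijkl->ijkl', A, B)

print(C_einsum.shape)                          # (2, 5, 4, 3)
print(np.array_equal(C_broadcast, C_einsum))   # True
```

Both arrays are integer-valued, so an exact comparison with `np.array_equal` is enough here; with floats, `np.allclose` would be the safer check.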

Upvotes: 1
