galah92

Reputation: 3991

np.einsum performance of 4 matrix multiplications

Given the following 4 arrays:

M = np.arange(35 * 37 * 59).reshape([35, 37, 59])
A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
C = np.arange(59 * 27).reshape([59, 27])

I'm using einsum to compute:

D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=True)

But I found it to be much less performant than:

tmp = np.einsum('xyf,xtf->tfy', A, M, optimize=True)
tmp = np.einsum('ytpf,yft->ftp', B, tmp, optimize=True)
D2 = np.einsum('fr,ftp->tpr', C, tmp, optimize=True)

And I can't understand why.
Overall, I'm trying to optimize this computation as much as I can. I've read about the np.tensordot function, but I can't figure out how to apply it to this computation.
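For reference, here is a minimal timing sketch of how I'm comparing the two approaches (using timeit on the arrays defined above; exact numbers will of course vary by machine):

import timeit

def run_d1():
    # single einsum call over all four arrays
    return np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=True)

def run_d2():
    # the same contraction split into three pairwise einsum calls
    tmp = np.einsum('xyf,xtf->tfy', A, M, optimize=True)
    tmp = np.einsum('ytpf,yft->ftp', B, tmp, optimize=True)
    return np.einsum('fr,ftp->tpr', C, tmp, optimize=True)

print(timeit.timeit(run_d1, number=10))
print(timeit.timeit(run_d2, number=10))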

Upvotes: 2

Views: 1598

Answers (2)

Paul Fackler

Reputation: 21

Although it is true that a greedy algorithm (there are several) may not find the optimal ordering in this case, that has nothing to do with the puzzle here. With the D2 approach you have fixed the order of operations yourself, which in this case is (((A,M),B),C), or equivalently (((M,A),B),C). This happens to be the optimal path. The three optimize=True arguments are not needed and are ignored, because no path optimization is performed when there are only two operands.

The slowdown of the D1 method comes from the need to find the optimal ordering of a 4-array operation. If you first found the path and then passed it, together with the 4 arrays, to einsum via optimize=path, my guess is that the two methods would be essentially equivalent. Thus the slowdown is due to the optimization step for D1. Although I am not sure how the optimal ordering is found, based on unpublished work I've done, this task will generally have O(3^n) worst-case behavior, where n is the number of arrays.
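A minimal sketch of that idea, assuming the arrays from the question are in scope: compute the contraction path once with np.einsum_path and pass it to einsum via the optimize argument, so the path search is not repeated on every call.

# Find the contraction path once, up front
path, _ = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize='optimal')

# Reuse the precomputed path; einsum then skips the path search
D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize=path)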

Upvotes: 1

Daniel

Reputation: 19547

Looks like you stumbled onto a case where the greedy path gives a non-optimal scaling.

>>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="greedy")
>>> print(desc)
  Complete contraction:  xyf,xtf,ytpf,fr->tpr
         Naive scaling:  6
     Optimized scaling:  5
      Naive FLOP count:  3.219e+10
  Optimized FLOP count:  4.165e+08
   Theoretical speedup:  77.299
  Largest intermediate:  5.371e+06 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   5              ytpf,xyf->xptf                         xtf,fr,xptf->tpr
   4               xptf,xtf->ptf                              fr,ptf->tpr
   4                 ptf,fr->tpr                                 tpr->tpr

>>> path, desc = np.einsum_path('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal")
>>> print(desc)
  Complete contraction:  xyf,xtf,ytpf,fr->tpr
         Naive scaling:  6
     Optimized scaling:  4
      Naive FLOP count:  3.219e+10
  Optimized FLOP count:  2.744e+07
   Theoretical speedup:  1173.425
  Largest intermediate:  1.535e+05 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
   4                xtf,xyf->ytf                         ytpf,fr,ytf->tpr
   4               ytf,ytpf->ptf                              fr,ptf->tpr
   4                 ptf,fr->tpr                                 tpr->tpr

Using np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize="optimal") should have you running at peak performance. I can look into this edge case to see if the greedy algorithm can catch it.
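For completeness, a small sketch (with the arrays from the question) that uses the optimal path and checks the result against your manual three-step version:

# Single call with the optimal contraction path
D1 = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C, optimize='optimal')

# Manual three-step contraction from the question
tmp = np.einsum('xyf,xtf->tfy', A, M)
tmp = np.einsum('ytpf,yft->ftp', B, tmp)
D2 = np.einsum('fr,ftp->tpr', C, tmp)

print(np.allclose(D1, D2))  # both orderings give the same result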

Upvotes: 4
