galah92

Reputation: 4001

Optimizing tensor multiplications

I've got a real-time image processing program I'm trying to optimize, and it all boils down to matrix multiplications. Consider 3 tensors I'm calculating in the initialization stage:

  1. A = np.arange(35 * 51 * 59).reshape([35, 51, 59])
  2. B = np.arange(37 * 51 * 51 * 59).reshape([37, 51, 51, 59])
  3. C = np.arange(59 * 27).reshape([59, 27])

Each frame, I get new data in the form of a fourth tensor:

  4. M, of shape [35, 37, 59] (the 'xyf' operand in the expression below)

Currently, I'm calculating D = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C), where D is my desired result, and it's the major bottleneck of the program. There are two directions I'm trying to follow in order to optimize it.

First, I tried to come up with a tensor T, a function of A, B, C that I can pre-calculate, so that everything boils down to D = np.tensordot(M, T, axes=..). I wasn't successful. I spent a lot of time on it; is it even possible at all?

Moreover, the program itself is written in MATLAB. As it doesn't have a built-in tensor multiplication function (an einsum or tensordot equivalent), I'm currently using the tprod toolbox, and doing:

temp1 = etprod('dcb', A, 'abc', M, 'adc');
temp2 = etprod('dbc', B, 'abcd', temp1, 'adb');
D = etprod('cdb', C, 'ab', temp2, 'acd');

As the default matrix product in MATLAB (for 2D matrices) is much faster than etprod, I thought about reshaping A, B, C, M to 2D arrays in a way that lets me multiply 2D matrices using the default function, without hand-written for loops. I wasn't successful with that either.
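One way the reshaping idea can be sketched, shown here in NumPy rather than MATLAB: the index f appears in every operand, so it cannot be summed until C is applied, which means the first two contractions are batched matrix products and only the last one is a plain 2D product. The small sizes and random data below are assumptions for illustration, not the sizes from the question.

```python
import numpy as np

# Small stand-in sizes (assumption); the question uses
# x=35, y=37, t=51, p=51, f=59, r=27.
x, y, t, p, f, r = 3, 4, 5, 6, 7, 2
rng = np.random.default_rng(0)
A = rng.standard_normal((x, t, f))
B = rng.standard_normal((y, t, p, f))
C = rng.standard_normal((f, r))
M = rng.standard_normal((x, y, f))

# Step 1: sum over x, batched over f.
# (f, y, x) @ (f, x, t) -> (f, y, t)
temp1 = np.matmul(M.transpose(2, 1, 0), A.transpose(2, 0, 1))

# Step 2: sum over y, batched over (t, f).
# (t, f, 1, y) @ (t, f, y, p) -> (t, f, 1, p), then drop the singleton axis.
lhs = temp1.transpose(2, 0, 1)[:, :, None, :]
rhs = B.transpose(1, 3, 0, 2)
temp2 = np.matmul(lhs, rhs)[:, :, 0, :]          # shape (t, f, p)

# Step 3: sum over f with an ordinary 2D matrix product.
# (t*p, f) @ (f, r) -> (t*p, r)
D_2d = (temp2.transpose(0, 2, 1).reshape(t * p, f) @ C).reshape(t, p, r)

# Reference: the single einsum call from the question.
D_ref = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C)
print(np.allclose(D_2d, D_ref))
```

Only step 3 is a true single matrix product; steps 1 and 2 are stacks of small products, so any speedup depends on how well the platform handles batched multiplication.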

Any thoughts? Thanks!

Upvotes: 1

Views: 203

Answers (1)

Paul Fackler

Reputation: 21

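If this operation is done many times with different values of M, we could define a tensor that depends only on A, B and C, with the output indices first and the indices of M last:

D0 = np.einsum('xtf,ytpf,fr->tprxyf', A, B, C)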

The whole operation could be broken into binary steps:

D0=np.einsum('xtf,ytpf->xyptf',A,B)
D0=np.einsum('xyptf,fr->tprxyf',D0,C)
D=np.einsum('tprxyf,xyf->tpr',D0,M)

The final operation uses D0 and M and can be coded as a matrix vector operation. In Matlab it would be

D=reshape(D0,[],numel(M))*M(:);

which could then be reordered as desired. We could write this order as (((A,B),C),M)
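A minimal NumPy sketch of this precompute-then-multiply idea, assuming D0 is laid out with the output indices (t, p, r) first and the indices of M (x, y, f) last. The small sizes and random data are assumptions for illustration only.

```python
import numpy as np

# Small stand-in sizes (assumption); the question uses
# x=35, y=37, t=51, p=51, f=59, r=27.
x, y, t, p, f, r = 3, 4, 5, 6, 7, 2
rng = np.random.default_rng(0)
A = rng.standard_normal((x, t, f))
B = rng.standard_normal((y, t, p, f))
C = rng.standard_normal((f, r))
M = rng.standard_normal((x, y, f))

# Precompute once: depends only on A, B, C.
D0 = np.einsum('xtf,ytpf,fr->tprxyf', A, B, C)

# Per frame: flatten D0 to (t*p*r, x*y*f) and M to a vector,
# do one matrix-vector product, then restore the (t, p, r) shape.
D_fast = (D0.reshape(-1, M.size) @ M.ravel()).reshape(t, p, r)

# Reference: the single einsum call from the question.
D_ref = np.einsum('xyf,xtf,ytpf,fr->tpr', M, A, B, C)
print(np.allclose(D_fast, D_ref))
```

At the sizes in the question, D0 would hold 51 * 51 * 27 * 35 * 37 * 59, roughly 5.4 billion, elements (about 43 GB as float64), so this route is probably only practical if the dimensions are much smaller.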

It might be better, however, to use (((M,C),A),B):

D=np.einsum('xyf,fr->xyfr',M,C)
D=np.einsum('xyfr,xtf->ytfr',D,A)
D=np.einsum('ytfr,ytpf->tpr',D,B)

This ordering of operations has intermediate arrays with only 4 indices rather than one with 6. If the three binary steps together run faster than the single four-operand call, this may be an advantage.
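A quick way to check this ordering against the single call, and to let NumPy pick a pairwise order itself, is sketched below. The small sizes and random data are assumptions for illustration; `optimize=True` and `np.einsum_path` are standard NumPy features, but the path they choose may differ from the ordering above.

```python
import numpy as np

# Small stand-in sizes (assumption); the question uses
# x=35, y=37, t=51, p=51, f=59, r=27.
x, y, t, p, f, r = 3, 4, 5, 6, 7, 2
rng = np.random.default_rng(0)
A = rng.standard_normal((x, t, f))
B = rng.standard_normal((y, t, p, f))
C = rng.standard_normal((f, r))
M = rng.standard_normal((x, y, f))

expr = 'xyf,xtf,ytpf,fr->tpr'

# Manual pairwise ordering (((M,C),A),B): no intermediate has more than 4 indices.
step = np.einsum('xyf,fr->xyfr', M, C)
step = np.einsum('xyfr,xtf->ytfr', step, A)
D_pair = np.einsum('ytfr,ytpf->tpr', step, B)

# Single call, unoptimized (reference) and with NumPy choosing the order.
D_ref = np.einsum(expr, M, A, B, C)
D_opt = np.einsum(expr, M, A, B, C, optimize=True)

# Inspect the contraction order NumPy would use and its estimated cost.
path, report = np.einsum_path(expr, M, A, B, C, optimize='greedy')
print(path)
print(report)
print(np.allclose(D_pair, D_ref), np.allclose(D_opt, D_ref))
```

Timing the three variants at the real sizes (for example with `timeit`) would show whether the pairwise ordering actually pays off on a given machine.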

Upvotes: 1
