Daniellah

Reputation: 71

Multiplying matrices (dot product) along an axis with Numpy (without loops)

I am working with Numpy on an image processing problem, and I am trying to avoid loops and do the following:

I have a matrix M of dims NxNxKxK (that is, an NxN matrix of KxK matrices), and for each row I wish to multiply together (matrix product, np.dot) all N of the KxK matrices in that row. Doing this for every row of M gives a vector V (Nx1) of KxK matrices, where V[i] holds the product M[i,0] x M[i,1] x ... x M[i,N-1].

I implemented a solution to this problem using loops, and I can't figure out a way to do this without loops.

Implementation (with loops):

    import functools
    import numpy as np

    a = np.array([[1,1,1], [1,1,1], [1,1,1]])
    mat = np.array([[a,a,a,a], [a*2,a*2,a*2,a*2], [a*3,a*3,a*3,a*3],
                    [a*4,a*4,a*4,a*4]])  # the original matrix, shape (4, 4, 3, 3)
    N, _, k, _ = mat.shape
    result = np.empty((N, k, k))  # resulting matrix
    for i in range(N):
        # chain-multiply the N (k x k) matrices in row i
        result[i] = functools.reduce(np.dot, mat[i])
    print(result)

Upvotes: 7

Views: 1734

Answers (1)

user6655984


The following uses reduce but not a loop over N:

    mat = mat.swapaxes(0, 1)
    result = functools.reduce(lambda a, b: np.einsum('ijk,ikl->ijl', a, b), mat)

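The einsum subscripts 'jk,kl->jl' express ordinary matrix multiplication; the extra leading index i makes it a batched product, carried out independently for each value of the first index. reduce iterates over the first axis of its argument, so if you passed the original mat, the slices being combined (mat[0], mat[1], ...) would be its rows, and the index i inside each slice would be the second (column) index of mat. The products would then run down each column of mat. You wanted them to run along each row, hence the use of swapaxes.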
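As a sanity check, here is a minimal sketch that compares the loop version with the einsum version on a small random input. The array sizes, the random data and the variable names are assumptions made for illustration. The last variant uses np.matmul, which also broadcasts over leading axes; it is not part of the answer above, just an equivalent way to write the same batched product.

```python
import functools

import numpy as np

rng = np.random.default_rng(0)
N, K = 4, 3                      # assumed small sizes, for illustration only
mat = rng.random((N, N, K, K))   # an N x N grid of K x K matrices

# Reference: explicit loop over rows, chaining np.dot along each row.
expected = np.stack([functools.reduce(np.dot, mat[i]) for i in range(N)])

# cols[j] has shape (N, K, K): column j of the grid, indexed by row i.
cols = mat.swapaxes(0, 1)

# Batched product over i at every reduce step, as in the answer.
via_einsum = functools.reduce(
    lambda a, b: np.einsum('ijk,ikl->ijl', a, b), cols)

# Same batched product written with np.matmul (an assumed alternative).
via_matmul = functools.reduce(np.matmul, cols)

print(np.allclose(expected, via_einsum))
print(np.allclose(expected, via_matmul))
```

Both variants still call reduce N-1 times; what disappears is the Python-level loop over rows.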

Whether this is faster or slower than the for-loop version depends on the relative sizes of N and k. np.dot is highly optimized, but if the loop over N is very long, einsum might win. Some %timeit results:

  • With N=100, k=2, the for-loop version takes 7.5 ms and the einsum version takes 4.31 ms.
  • With N=100, k=20, the for-loop version takes 27.3 ms and the einsum version takes 153 ms.
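To reproduce this kind of comparison outside IPython, something like the following sketch with the standard timeit module could be used. The helper names (make_input, loop_version, einsum_version, time_both) are hypothetical, the input is random data scaled by 1/k so that long products stay finite, and the timings will vary by machine and NumPy version.

```python
import functools
import timeit

import numpy as np


def make_input(n, k, seed=0):
    """Hypothetical helper: random (n, n, k, k) array, scaled by 1/k
    so that a product of n matrices neither overflows nor explodes."""
    rng = np.random.default_rng(seed)
    return rng.random((n, n, k, k)) / k


def loop_version(mat):
    # Python loop over rows, np.dot chained along each row.
    return np.stack([functools.reduce(np.dot, row) for row in mat])


def einsum_version(mat):
    # One reduce over columns; each step is a batched product over all rows.
    step = lambda a, b: np.einsum('ijk,ikl->ijl', a, b)
    return functools.reduce(step, mat.swapaxes(0, 1))


def time_both(n, k, number=3):
    """Return average seconds per call as (loop, einsum)."""
    mat = make_input(n, k)
    t_loop = timeit.timeit(lambda: loop_version(mat), number=number) / number
    t_einsum = timeit.timeit(lambda: einsum_version(mat), number=number) / number
    return t_loop, t_einsum


for n, k in [(100, 2), (100, 20)]:
    t_loop, t_einsum = time_both(n, k)
    print(f"N={n}, k={k}: loop {t_loop * 1e3:.2f} ms, "
          f"einsum {t_einsum * 1e3:.2f} ms")
```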

So there is a modest gain when the matrices are small (k=2 above) and a major loss when they are larger (k=20 above). But you did not ask for an "efficient" solution, you asked for one "without loops", so here it is ("without loops" != "faster"). As Divakar suggested in a comment, you are probably better off keeping the code as is.

Upvotes: 2
