Miguel

Reputation: 1363

Algorithm for tensordot implemented in numba is much slower than numpy's

I am trying to extend numpy's "tensordot" so that expressions like K_ijklm = A_ki * B_jml can be written clearly as: K = mytensordot(A,B,[2,0],[1,4,3])

To my understanding, numpy's tensordot (with optional argument 0) would be able to do something like this: K_kijml = A_ki * B_jml, i.e. keeping the order of the indices. I would therefore have to do a number of np.swapaxes() calls to obtain the tensor K_ijklm, which in a complicated case can be an easy source of errors (potentially very hard to debug).

The problem is that my implementation is slow (10x slower than tensordot [EDIT: It is actually MUCH slower than that]), even when using numba. I was wondering if anyone would have some insight on what could be done to improve the performance of my algorithm.

MWE

import numpy as np
import numba as nb
import itertools
import timeit

@nb.jit()
def myproduct(dimN):
    N=np.prod(dimN)
    L=len(dimN)
    Product=np.zeros((N,L),dtype=np.int32)
    rn=0
    for n in range(1,N):
        for l in range(L):
            if l==0:
                rn=1
            v=Product[n-1,L-1-l]+rn
            rn = 0
            if v == dimN[L-1-l]:
                v = 0
                rn = 1
            Product[n,L-1-l]=v
    return Product
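For readers decoding myproduct: it is an odometer that enumerates every multi-index of shape dimN in row-major order (presumably why itertools is imported above). A pure-Python reference implementation, not part of the question's code:

```python
import numpy as np
import itertools

# Reference for myproduct: list every multi-index of an array with shape
# dimN, in row-major "odometer" order, as an (N, L) int array
# (N = prod(dimN), L = len(dimN)).
def myproduct_ref(dimN):
    return np.array(list(itertools.product(*(range(d) for d in dimN))),
                    dtype=np.int32)

print(myproduct_ref([2, 3]).tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
```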

@nb.jit()
def mytensordot(A,B,iA,iB):
    iA,iB = np.array(iA,dtype=np.int32),np.array(iB,dtype=np.int32)
    dimA,dimB = A.shape,B.shape
    NdimA,NdimB=len(dimA),len(dimB)

    if len(iA) != NdimA: raise ValueError("iA must be same size as dim A")
    if len(iB) != NdimB: raise ValueError("iB must be same size as dim B")

    NdimN = NdimA + NdimB
    dimN=np.zeros(NdimN,dtype=np.int32)
    dimN[iA]=dimA
    dimN[iB]=dimB
    Out=np.zeros(dimN)
    indexes = myproduct(dimN)

    for nidxs in indexes:
        idxA = tuple(nidxs[iA])
        idxB = tuple(nidxs[iB])
        v = A[idxA] * B[idxB]
        Out[tuple(nidxs)]=v
    return Out



A=np.random.random((4,5,3))
B=np.random.random((6,4))

def runmytdot():
    return mytensordot(A,B,[0,2,3],[1,4])
def runtensdot():
    return np.tensordot(A,B,0).swapaxes(1,3).swapaxes(2,3)


print(np.all(runmytdot()==runtensdot()))
print(timeit.timeit(runmytdot,number=100))
print(timeit.timeit(runtensdot,number=100))

Result:

True
1.4962144780438393
0.003484356915578246

Upvotes: 4

Views: 1452

Answers (2)

senderle

Reputation: 151027

You have run into a known issue. numpy.zeros requires a tuple when creating a multidimensional array. If you pass something other than a tuple, it sometimes works, but that's only because numpy is smart about converting the object into a tuple first.

The trouble is that numba does not currently support conversion of arbitrary iterables into tuples. So this line fails when you try to compile it in nopython=True mode. (A couple of others fail too, but this is the first.)

Out=np.zeros(dimN)
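For contrast, outside numba the same call is fine, because plain NumPy coerces any iterable of ints into a shape tuple:

```python
import numpy as np

# In plain NumPy, zeros() accepts any iterable of ints and converts it
# to a tuple internally -- this is why the uncompiled code runs fine.
dimN = np.array([2, 3, 4], dtype=np.int32)
a = np.zeros(dimN)
print(a.shape)   # (2, 3, 4)
```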

In theory you could call np.prod(dimN), create a flat array of zeros, and reshape it, but then you run into the very same problem: the reshape method of numpy arrays requires a tuple!

This is quite a vexing problem with numba -- I had not encountered it before. I really doubt the solution I have found is the correct one, but it is a working solution that allows us to compile a version in nopython=True mode.

The core idea is to avoid using tuples for indexing by directly implementing an indexer that follows the strides of the array:

@nb.jit(nopython=True)
def index_arr(a, ix_arr):
    strides = np.array(a.strides) / a.itemsize
    ix = int((ix_arr * strides).sum())
    return a.ravel()[ix]

@nb.jit(nopython=True)
def index_set_arr(a, ix_arr, val):
    strides = np.array(a.strides) / a.itemsize
    ix = int((ix_arr * strides).sum())
    a.ravel()[ix] = val

This allows us to get and set values without needing a tuple.
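The strides arithmetic these helpers rely on can be sanity-checked in plain NumPy (leaving numba out of it):

```python
import numpy as np

# For a C-contiguous array, the flat position of a multi-index is the
# dot product of the index with the strides (measured in elements).
a = np.arange(24).reshape(2, 3, 4)
ix = np.array([1, 2, 3])
strides = np.array(a.strides) // a.itemsize   # byte strides -> element strides
flat = int((ix * strides).sum())
assert a.ravel()[flat] == a[1, 2, 3]          # both are 23
```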

We can also avoid using reshape by passing the output buffer into the jitted function, and wrapping that function in a helper:

@nb.jit()  # We can't use nopython mode here...
def mytensordot(A, B, iA, iB):
    iA, iB = np.array(iA, dtype=np.int32), np.array(iB, dtype=np.int32)
    dimA, dimB = A.shape, B.shape
    NdimA, NdimB = len(dimA), len(dimB)

    if len(iA) != NdimA:
        raise ValueError("iA must be same size as dim A")
    if len(iB) != NdimB:
        raise ValueError("iB must be same size as dim B")

    NdimN = NdimA + NdimB
    dimN = np.zeros(NdimN, dtype=np.int32)
    dimN[iA] = dimA
    dimN[iB] = dimB
    Out = np.zeros(dimN)
    return mytensordot_jit(A, B, iA, iB, dimN, Out)

The helper contains no loops, so while it adds some overhead, the overhead is trivial. Here's the final jitted function:

@nb.jit(nopython=True)
def mytensordot_jit(A, B, iA, iB, dimN, Out):
    for i in range(np.prod(dimN)):
        nidxs = int_to_idx(i, dimN)
        a = index_arr(A, nidxs[iA])
        b = index_arr(B, nidxs[iB])
        index_set_arr(Out, nidxs, a * b)
    return Out
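One piece the answer leaves out is int_to_idx, which unravels the flat loop counter into a multi-index over dimN. A plausible reconstruction (the name is from the answer, but the body is my guess; decorate it with @nb.jit(nopython=True) alongside the others):

```python
import numpy as np

# Hypothetical int_to_idx: unravel a flat row-major index i into a
# multi-index for shape dimN -- what np.unravel_index does, written out
# explicitly so it can also run in nopython mode.
def int_to_idx(i, dimN):
    idx = np.zeros(len(dimN), dtype=np.int32)
    for k in range(len(dimN) - 1, -1, -1):
        idx[k] = i % dimN[k]
        i //= dimN[k]
    return idx

# int_to_idx(5, [2, 3]) -> array([1, 2]), like np.unravel_index(5, (2, 3))
```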

Unfortunately, this does not wind up generating as much of a speedup as we might like. On smaller arrays it's about 5x slower than tensordot; on larger arrays it's still 50x slower. (But at least it's not 1000x slower!) This is not too surprising in retrospect, since dot and tensordot are both using BLAS under the hood, as @hpaulj reminds us.

After finishing this code, I saw that einsum has solved your real problem -- nice!
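For concreteness: the MWE's mytensordot(A, B, [0,2,3], [1,4]) collapses to a single einsum call, with the output index order spelled out directly (the labels below are arbitrary placeholders):

```python
import numpy as np

A = np.random.random((4, 5, 3))
B = np.random.random((6, 4))

# A's axes go to output positions 0, 2, 3 and B's to positions 1, 4 --
# exactly the [0,2,3], [1,4] spec from the question's MWE.
K = np.einsum('abc,de->adbce', A, B)
assert np.allclose(K, np.tensordot(A, B, 0).swapaxes(1, 3).swapaxes(2, 3))
```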

But the underlying issue that your original question points to -- that indexing with arbitrary-length tuples is not possible in jitted code -- is still a frustration. So hopefully this will be useful to someone else!

Upvotes: 4

hpaulj

Reputation: 231425

tensordot with scalar axes values can be obscure. I explored it in

How does numpy.tensordot function works step-by-step?

There I deduced that np.tensordot(A, B, axes=0) is equivalent to using axes=[[], []].

In [757]: A=np.random.random((4,5,3))
     ...: B=np.random.random((6,4))

In [758]: np.tensordot(A,B,0).shape
Out[758]: (4, 5, 3, 6, 4)
In [759]: np.tensordot(A,B,[[],[]]).shape
Out[759]: (4, 5, 3, 6, 4)

That in turn is equivalent to calling dot with a new size-1 sum-of-products dimension:

In [762]: np.dot(A[...,None],B[...,None,:]).shape
Out[762]: (4, 5, 3, 6, 4)

(4,5,3,1) * (6,1,4)   # the 1 is the last of A and 2nd to the last of B

dot is fast, using BLAS (or equivalent) code. Swapping axes and reshaping is also relatively fast.

einsum gives us a lot of control over axes

Replicating the above product:

In [768]: np.einsum('jml,ki->jmlki',A,B).shape
Out[768]: (4, 5, 3, 6, 4)

and with swapping:

In [769]: np.einsum('jml,ki->ijklm',A,B).shape
Out[769]: (4, 4, 6, 3, 5)

A minor point - the double swap can be written as one transpose:

.swapaxes(1,3).swapaxes(2,3)
.transpose(0,3,1,2,4)
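A quick check that the two spellings agree, on the (4, 5, 3, 6, 4) result from above:

```python
import numpy as np

K = np.random.random((4, 5, 3, 6, 4))
a = K.swapaxes(1, 3).swapaxes(2, 3)
b = K.transpose(0, 3, 1, 2, 4)
assert a.shape == b.shape == (4, 6, 5, 3, 4)
assert np.array_equal(a, b)   # same permutation, written two ways
```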

Upvotes: 2
