Reputation: 1363
I am trying to extend numpy's "tensordot" so that operations like:
K_ijklm = A_ki * B_jml
can be written in a clear way, like this: K = mytensordot(A, B, [2, 0], [1, 4, 3])
To my understanding, numpy's tensordot (with the optional argument 0) can only produce something like: K_kijml = A_ki * B_jml
i.e. it keeps the order of the indexes. I would therefore have to follow it with a number of np.swapaxes()
calls to obtain the array K_ijklm, which in a complicated case is an easy source of errors (and potentially very hard to debug).
The problem is that my implementation is slow (10x slower than tensordot [EDIT: it is actually MUCH slower than that]), even when using numba. I was wondering if anyone has some insight into what could be done to improve the performance of my algorithm.
import numpy as np
import numba as nb
import itertools
import timeit


@nb.jit()
def myproduct(dimN):
    # Enumerate every multi-index of an array with dimensions dimN;
    # row n of the result is the n-th index in C order.
    N = np.prod(dimN)
    L = len(dimN)
    Product = np.zeros((N, L), dtype=np.int32)
    rn = 0
    for n in range(1, N):
        for l in range(L):
            if l == 0:
                rn = 1
            v = Product[n-1, L-1-l] + rn
            rn = 0
            if v == dimN[L-1-l]:
                v = 0
                rn = 1
            Product[n, L-1-l] = v
    return Product


@nb.jit()
def mytensordot(A, B, iA, iB):
    # Generalized outer product: A's axes go to output positions iA,
    # B's axes go to output positions iB.
    iA, iB = np.array(iA, dtype=np.int32), np.array(iB, dtype=np.int32)
    dimA, dimB = A.shape, B.shape
    NdimA, NdimB = len(dimA), len(dimB)
    if len(iA) != NdimA:
        raise ValueError("iA must be same size as dim A")
    if len(iB) != NdimB:
        raise ValueError("iB must be same size as dim B")
    NdimN = NdimA + NdimB
    dimN = np.zeros(NdimN, dtype=np.int32)
    dimN[iA] = dimA
    dimN[iB] = dimB
    Out = np.zeros(dimN)
    indexes = myproduct(dimN)
    for nidxs in indexes:
        idxA = tuple(nidxs[iA])
        idxB = tuple(nidxs[iB])
        v = A[idxA] * B[idxB]
        Out[tuple(nidxs)] = v
    return Out


A = np.random.random((4, 5, 3))
B = np.random.random((6, 4))


def runmytdot():
    return mytensordot(A, B, [0, 2, 3], [1, 4])


def runtensdot():
    return np.tensordot(A, B, 0).swapaxes(1, 3).swapaxes(2, 3)


print(np.all(runmytdot() == runtensdot()))
print(timeit.timeit(runmytdot, number=100))
print(timeit.timeit(runtensdot, number=100))
This prints:
True
1.4962144780438393
0.003484356915578246
Upvotes: 4
Views: 1452
Reputation: 151027
You have run into a known issue. numpy.zeros requires a tuple when creating a multidimensional array. If you pass something other than a tuple, it sometimes works, but that's only because numpy is smart about converting the object into a tuple first.
The trouble is that numba does not currently support conversion of arbitrary iterables into tuples. So this line fails when you try to compile it in nopython=True mode (a couple of others fail too, but this is the first):
Out = np.zeros(dimN)
In theory you could call np.prod(dimN), create a flat array of zeros, and reshape it, but then you run into the very same problem: the reshape method of numpy arrays requires a tuple!
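To illustrate the issue (a minimal hypothetical example, assuming the numba behavior described above), passing the shape as an array fails to compile in nopython mode, while a literal tuple is fine:

@nb.jit(nopython=True)
def zeros_from_tuple():
    return np.zeros((3, 4))      # compiles: the shape is a tuple

@nb.jit(nopython=True)
def zeros_from_array(shape_arr):
    return np.zeros(shape_arr)   # fails to compile: numba can't turn the array into a tuple

zeros_from_tuple()                                   # works
zeros_from_array(np.array([3, 4], dtype=np.int32))   # raises a numba typing error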
This is quite a vexing problem with numba -- I had not encountered it before. I really doubt the solution I have found is the correct one, but it is a working solution that allows us to compile a version in nopython=True mode.
The core idea is to avoid using tuples for indexing by directly implementing an indexer that follows the strides of the array:
@nb.jit(nopython=True)
def index_arr(a, ix_arr):
    # Look up a single element of a using an array of indices instead of a tuple.
    strides = np.array(a.strides) / a.itemsize
    ix = int((ix_arr * strides).sum())
    return a.ravel()[ix]

@nb.jit(nopython=True)
def index_set_arr(a, ix_arr, val):
    # Set a single element of a using an array of indices instead of a tuple.
    strides = np.array(a.strides) / a.itemsize
    ix = int((ix_arr * strides).sum())
    a.ravel()[ix] = val
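As a quick sanity check (illustration only; this assumes a C-contiguous array so that ravel() returns a view), the stride-based indexers behave like ordinary tuple indexing:

a = np.arange(24.0).reshape(2, 3, 4)
ix = np.array([1, 2, 3])
print(index_arr(a, ix) == a[1, 2, 3])   # True
index_set_arr(a, ix, -1.0)
print(a[1, 2, 3])                       # -1.0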
This allows us to get and set values without needing a tuple. We can also avoid using reshape by passing the output buffer into the jitted function, and wrapping that function in a helper:
@nb.jit()  # We can't use nopython mode here...
def mytensordot(A, B, iA, iB):
    iA, iB = np.array(iA, dtype=np.int32), np.array(iB, dtype=np.int32)
    dimA, dimB = A.shape, B.shape
    NdimA, NdimB = len(dimA), len(dimB)
    if len(iA) != NdimA:
        raise ValueError("iA must be same size as dim A")
    if len(iB) != NdimB:
        raise ValueError("iB must be same size as dim B")
    NdimN = NdimA + NdimB
    dimN = np.zeros(NdimN, dtype=np.int32)
    dimN[iA] = dimA
    dimN[iB] = dimB
    Out = np.zeros(dimN)
    return mytensordot_jit(A, B, iA, iB, dimN, Out)
The helper adds some overhead, but since it contains no loops, the overhead is pretty trivial. Here's the final jitted function:
@nb.jit(nopython=True)
def mytensordot_jit(A, B, iA, iB, dimN, Out):
    for i in range(np.prod(dimN)):
        nidxs = int_to_idx(i, dimN)
        a = index_arr(A, nidxs[iA])
        b = index_arr(B, nidxs[iB])
        index_set_arr(Out, nidxs, a * b)
    return Out
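This relies on an int_to_idx helper that is not shown above. A minimal sketch of what such a helper could look like (an assumption, not code from the answer) is a routine that turns a flat C-order index into a multi-index array:

@nb.jit(nopython=True)
def int_to_idx(i, dimN):
    # Hypothetical helper: convert flat index i into a C-order multi-index
    # over an array with dimensions dimN (last index varies fastest).
    idx = np.zeros(len(dimN), dtype=np.int32)
    for k in range(len(dimN) - 1, -1, -1):
        idx[k] = i % dimN[k]
        i //= dimN[k]
    return idx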
Unfortunately, this does not wind up generating as much of a speedup as we might like. On smaller arrays it's about 5x slower than tensordot; on larger arrays it's still 50x slower. (But at least it's not 1000x slower!) This is not too surprising in retrospect, since dot and tensordot are both using BLAS under the hood, as @hpaulj reminds us.
After finishing this code, I saw that einsum has solved your real problem -- nice!
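For reference, the einsum one-liner for the test case above would look something like this (placing A's axes at output positions 0, 2, 3 and B's at 1, 4):

K = np.einsum('acd,be->abcde', A, B)   # same as mytensordot(A, B, [0, 2, 3], [1, 4])
print(np.allclose(K, np.tensordot(A, B, 0).swapaxes(1, 3).swapaxes(2, 3)))   # True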
But the underlying issue that your original question points to -- that indexing with arbitrary-length tuples is not possible in jitted code -- is still a frustration. So hopefully this will be useful to someone else!
Upvotes: 4
Reputation: 231425
tensordot with scalar axes values can be obscure. I explored it in
How does numpy.tensordot function works step-by-step?
There I deduced that np.tensordot(A, B, axes=0) is equivalent to using axes=[[], []].
In [757]: A=np.random.random((4,5,3))
...: B=np.random.random((6,4))
In [758]: np.tensordot(A,B,0).shape
Out[758]: (4, 5, 3, 6, 4)
In [759]: np.tensordot(A,B,[[],[]]).shape
Out[759]: (4, 5, 3, 6, 4)
That in turn is equivalent to calling dot with a new size-1 sum-of-products dimension:
In [762]: np.dot(A[...,None],B[...,None,:]).shape
Out[762]: (4, 5, 3, 6, 4)
(4,5,3,1) * (6,1,4) # the 1 is the last of A and 2nd to the last of B
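A quick check (for illustration) that the dot form with inserted size-1 axes matches tensordot with axes=0:

np.allclose(np.dot(A[..., None], B[..., None, :]), np.tensordot(A, B, 0))   # True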
dot is fast, using BLAS (or equivalent) code. Swapping axes and reshaping is also relatively fast.
einsum gives us a lot of control over the axes. Replicating the above product:
In [768]: np.einsum('jml,ki->jmlki',A,B).shape
Out[768]: (4, 5, 3, 6, 4)
and with swapping:
In [769]: np.einsum('jml,ki->ijklm',A,B).shape
Out[769]: (4, 4, 6, 3, 5)
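Tying this back to the question's interface (illustration): that swapped einsum places A's axes 'jml' at output positions 1, 4, 3 and B's axes 'ki' at positions 2, 0, i.e.

np.allclose(np.einsum('jml,ki->ijklm', A, B), mytensordot(A, B, [1, 4, 3], [2, 0]))   # True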
A minor point - the double swap:
.swapaxes(1,3).swapaxes(2,3)
can be written as one transpose:
.transpose(0,3,1,2,4)
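A quick check (for illustration) that the single transpose reproduces the two swaps:

C = np.tensordot(A, B, 0)
np.array_equal(C.swapaxes(1, 3).swapaxes(2, 3), C.transpose(0, 3, 1, 2, 4))   # True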
Upvotes: 2