user541686

Reputation: 210755

Compute numpy.inner() over first (instead of last) axis

I'm trying to make a function like numpy.inner, but which sums over the first axis of both arrays instead of the last axis. Currently I'm using tensordot with rollaxis:

import numpy
def inner1(a, b):
    # Move axis 0 of a to the end, then contract it with axis 0 of b
    return numpy.tensordot(numpy.rollaxis(a, 0, len(a.shape)), b, 1)
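To pin down the intended semantics, here is a small check, assuming Python 3 and NumPy, that compares `inner1` with an explicit loop over the first axis. `inner_loop` is a hypothetical reference helper written for this illustration, not a NumPy function.

```python
import numpy

def inner1(a, b):
    # Move axis 0 of a to the end, then contract it with axis 0 of b
    return numpy.tensordot(numpy.rollaxis(a, 0, len(a.shape)), b, 1)

def inner_loop(a, b):
    # Hypothetical reference implementation: sum over k of the outer
    # product of the matching slices a[k] and b[k]
    result = numpy.zeros(a.shape[1:] + b.shape[1:])
    for k in range(a.shape[0]):
        result += numpy.multiply.outer(a[k], b[k])
    return result

a = numpy.random.rand(4, 2, 3)
b = numpy.random.rand(4, 5)
print(inner1(a, b).shape)                       # (2, 3, 5)
print(numpy.allclose(inner1(a, b), inner_loop(a, b)))  # True
```

The result keeps a's remaining axes first, followed by b's remaining axes.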

but I'm wondering: is there a better way? Perhaps one that doesn't require me to roll the axes?

I feel like einsum should make this possible, but I'm not sure how to use it here.
It seems to require me to hard-code the dimensionality of a and b when I specify the subscripts string, which I can't really do here because there is no particular requirement on the input dimensionality.

(Note: I am aware that there are performance implications to summing over the first axis instead of the last, but I'm ignoring them here.)

Upvotes: 0

Views: 147

Answers (2)

hpaulj

Reputation: 231665

This isn't as pretty as the tensordot solution, but you can construct the einsum string from the ndim of the inputs:

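import numpy as np

ll = 'abcdefghijklmnopqrstuvw'
# a gets ll[0] plus one letter per remaining axis; b reuses ll[0] and continues with fresh letters
astr = ll[0]+ll[1:a.ndim]+','+ll[0]+ll[a.ndim:a.ndim+b.ndim-1]
np.einsum(astr,a,b)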
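The same string-building idea can be wrapped in a reusable function. This is a sketch assuming Python 3 and NumPy; `inner_first_einsum` is a hypothetical name. Unlike the snippet above, it spells out the output subscripts after `->`, so the axis order of the result does not rely on einsum's implicit alphabetical ordering.

```python
import string
import numpy as np

def inner_first_einsum(a, b):
    letters = string.ascii_lowercase
    # Labels needed: 1 shared (summed) + (a.ndim - 1) + (b.ndim - 1)
    n = a.ndim + b.ndim - 1
    if n > len(letters):
        raise ValueError("too many dimensions for this letter set")
    a_sub = letters[0] + letters[1:a.ndim]   # e.g. 'abc' for a 3-d a
    b_sub = letters[0] + letters[a.ndim:n]   # e.g. 'ad' for a 2-d b
    out_sub = letters[1:n]                   # every label except the summed one
    return np.einsum(a_sub + ',' + b_sub + '->' + out_sub, a, b)

a = np.random.rand(4, 2, 3)
b = np.random.rand(4, 5)
print(inner_first_einsum(a, b).shape)  # (2, 3, 5)
```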

np.einsum also lets you specify the axes as lists of integers instead of a subscripts string:

# list(...) is needed on Python 3, where range() no longer returns a list
np.einsum(a, [0]+list(range(1,a.ndim)), b, [0]+list(range(a.ndim,a.ndim+b.ndim-1)))
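A function version of the list form, assuming Python 3 and NumPy. `inner_first_sublist` is a hypothetical name, and passing an explicit output list (the optional trailing sublist) is a choice made here so the result order is stated rather than implied.

```python
import numpy as np

def inner_first_sublist(a, b):
    # Label 0 is shared by the first axis of both operands, so it is summed
    a_axes = [0] + list(range(1, a.ndim))
    # b's remaining axes get labels that a does not use
    b_axes = [0] + list(range(a.ndim, a.ndim + b.ndim - 1))
    # Output: every label except the summed one, a's axes first
    out_axes = list(range(1, a.ndim + b.ndim - 1))
    return np.einsum(a, a_axes, b, b_axes, out_axes)

a = np.random.rand(4, 2, 3)
b = np.random.rand(4, 5)
print(inner_first_sublist(a, b).shape)  # (2, 3, 5)
```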

For a pair of 3d and 2d arrays, these produce:

 np.einsum('abc,ad', a, b)
 np.einsum(a, [0,1,2], b, [0,3])

'...' doesn't work here because the ellipsis axes of the two operands are broadcast against each other (paired up where the shapes allow), whereas you want every axis except the first to stay distinct.
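A small demonstration of that broadcasting behavior, assuming Python 3 and NumPy; the array shapes are arbitrary choices for illustration.

```python
import numpy as np

a = np.random.rand(4, 3)
b = np.random.rand(4, 5)

# The wanted result: contract axis 0 only, keep the 3- and 5-length axes separate
wanted = np.tensordot(a, b, (0, 0))
print(wanted.shape)  # (3, 5)

# With '...', the trailing axes (3,) and (5,) must broadcast together,
# which they cannot, so einsum raises ValueError
try:
    np.einsum('i...,i...', a, b)
except ValueError as exc:
    print("ValueError:", exc)

# When the trailing shapes happen to match, '...' silently pairs them
# elementwise instead of keeping them separate
c = np.random.rand(4, 3)
print(np.einsum('i...,i...', a, c).shape)  # (3,), not (3, 3)
```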

While messier to write, the einsum solution is faster than the tensordot one (3x faster for small test arrays).
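A rough way to reproduce this kind of comparison, assuming Python 3 and NumPy. The helper names are hypothetical, and the ratio you see will depend on your NumPy version, hardware and array sizes, so no particular number is implied here.

```python
import timeit
import numpy as np

a = np.random.rand(4, 2, 3)
b = np.random.rand(4, 5)

def via_tensordot():
    # The question's approach: roll axis 0 to the end, then contract
    return np.tensordot(np.rollaxis(a, 0, a.ndim), b, 1)

def via_einsum():
    # The hard-coded einsum string for a 3-d a and a 2-d b
    return np.einsum('abc,ad', a, b)

# Make sure both compute the same thing before timing them
assert np.allclose(via_tensordot(), via_einsum())

t_td = timeit.timeit(via_tensordot, number=2000)
t_es = timeit.timeit(via_einsum, number=2000)
print(f"tensordot: {t_td:.4f}s  einsum: {t_es:.4f}s  ratio: {t_td / t_es:.1f}x")
```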


Another option with einsum is to reshape the arrays, reducing the 'remaining' dimensions down to one. This adds a bit of time to the calculation, but not a lot:

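# collapse the trailing axes into one, contract over the first axis, then restore the trailing shapes
np.einsum('ij,ik', a.reshape(a.shape[0],-1), b.reshape(a.shape[0],-1)).reshape(a.shape[1:]+b.shape[1:])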
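The reshape approach written out step by step, assuming Python 3 and NumPy; `inner_first_reshape` is a hypothetical name. After flattening, the contraction is an ordinary matrix product of a2.T with b2, so `a2.T @ b2` would work equally well in place of the einsum call.

```python
import numpy as np

def inner_first_reshape(a, b):
    n = a.shape[0]
    # Collapse all trailing axes of each operand into a single axis
    a2 = a.reshape(n, -1)
    b2 = b.reshape(n, -1)
    # 2-d contraction over the shared first axis
    out = np.einsum('ij,ik', a2, b2)
    # Restore the original trailing shapes of a, then b
    return out.reshape(a.shape[1:] + b.shape[1:])

a = np.random.rand(4, 2, 3)
b = np.random.rand(4, 5)
print(inner_first_reshape(a, b).shape)  # (2, 3, 5)
```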

Upvotes: 1

farenorth
farenorth

Reputation: 10791

I think what you want is np.tensordot(a, b, (0, 0)).
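A quick check, assuming Python 3 and NumPy, that this matches the question's rollaxis version. Passing axes=(0, 0) tells tensordot to sum axis 0 of a against axis 0 of b; the remaining axes of a come first in the result, followed by those of b. The np.moveaxis line is included only because NumPy's documentation recommends moveaxis over rollaxis.

```python
import numpy as np

a = np.random.rand(4, 2, 3)
b = np.random.rand(4, 5)

# Contract axis 0 of a with axis 0 of b directly
direct = np.tensordot(a, b, axes=(0, 0))
# The question's version: move axis 0 of a to the end, then contract it
rolled = np.tensordot(np.rollaxis(a, 0, a.ndim), b, 1)
# Same idea with np.moveaxis instead of np.rollaxis
moved = np.tensordot(np.moveaxis(a, 0, -1), b, 1)

print(direct.shape)                 # (2, 3, 5)
print(np.allclose(direct, rolled))  # True
print(np.allclose(direct, moved))   # True
```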

Upvotes: 2
