Joe

Reputation: 913

NumPy Tensordot axes=2

I know there are many questions about tensordot, and I've skimmed some of the 15-page mini-book answers that I'm sure people spent hours writing, but I haven't found an explanation of what axes=2 does.

The following example made me think that np.tensordot(b,c,axes=2) is the same as np.sum(b * c), except that it is returned as a 0-d array:

import numpy as np
b = np.array([[1,10],[100,1000]])
c = np.array([[2,3],[5,7]])
np.tensordot(b,c,axes=2)
Out: array(7532)
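For two 2-d arrays of the same shape, the observation above does hold: axes=2 pairs both axes of b with both axes of c, so every axis is summed away. A minimal check, using np.einsum as an independent way of writing the same full contraction:

```python
import numpy as np

b = np.array([[1, 10], [100, 1000]])
c = np.array([[2, 3], [5, 7]])

# axes=2 on 2-d inputs: last 2 axes of b paired with first 2 axes of c,
# i.e. all axes are summed, leaving a 0-d array.
full = np.tensordot(b, c, axes=2)

# The same full contraction written two other ways.
via_sum = np.sum(b * c)
via_einsum = np.einsum('ij,ij->', b, c)

print(full.shape)                                 # ()
print(int(full), int(via_sum), int(via_einsum))   # 7532 7532 7532
```

The equivalence with np.sum(b * c) only holds because both arrays are 2-d and the same shape; it is not the general meaning of axes=2.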

But then this failed:

a = np.arange(30).reshape((2,3,5))
np.tensordot(a,a,axes=2)   # ValueError: shape-mismatch for sum

If anyone can provide a concise explanation of np.tensordot(x,y,axes=2), and only axes=2, I would gladly accept it.

Upvotes: -1

Views: 623

Answers (1)

hpaulj

Reputation: 231385

In [70]: a = np.arange(24).reshape(2,3,4)
In [71]: np.tensordot(a,a,axes=2)
Traceback (most recent call last):
  File "<ipython-input-71-dbe04e46db70>", line 1, in <module>
    np.tensordot(a,a,axes=2)
  File "<__array_function__ internals>", line 5, in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

In my previous post I deduced that axes=2 translates to axes=([-2,-1],[0,1]):

How does numpy.tensordot function works step-by-step?
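That translation can be checked on arrays whose shapes do line up. In this sketch x and y are made-up example arrays chosen so that the last 2 axes of x, (3,4), match the first 2 axes of y, (3,4); the einsum string is one assumed-equivalent spelling of the same contraction:

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
y = np.arange(60).reshape(3, 4, 5)

# Integer form: contract the last 2 axes of x with the first 2 axes of y.
r_int = np.tensordot(x, y, axes=2)
# Explicit tuple form of the same pairing.
r_tup = np.tensordot(x, y, axes=([-2, -1], [0, 1]))
# einsum spelling: j and k are summed, i (from x) and l (from y) survive.
r_ein = np.einsum('ijk,jkl->il', x, y)

print(r_int.shape)                                                 # (2, 5)
print(np.array_equal(r_int, r_tup), np.array_equal(r_int, r_ein))  # True True
```

The uncontracted axes of x come first in the result, followed by the uncontracted axes of y.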

In [72]: np.tensordot(a,a,axes=([-2,-1],[0,1]))
Traceback (most recent call last):
  File "<ipython-input-72-efdbfe6ff0d3>", line 1, in <module>
    np.tensordot(a,a,axes=([-2,-1],[0,1]))
  File "<__array_function__ internals>", line 5, in tensordot
  File "/usr/local/lib/python3.8/dist-packages/numpy/core/numeric.py", line 1116, in tensordot
    raise ValueError("shape-mismatch for sum")
ValueError: shape-mismatch for sum

So that's trying to do a double-axis reduction, pairing the last 2 dimensions of the first a, (3,4), with the first 2 dimensions of the second a, (2,3). With this a, those sizes don't match. Evidently this integer axes form was intended for 2d arrays, without much thought given to 3d ones. It is not a 3-axis reduction.
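To make the pairing concrete, here is a small sketch. The helper int_axes_to_pairs is hypothetical (not part of NumPy); it just spells out the rule "last n axes of the first array, first n axes of the second" for the 3d a used above. The last lines show that an integer can still express a full 3-axis reduction of two same-shaped 3d arrays, but only if you pass axes=3 (a.ndim), not 2:

```python
import numpy as np

def int_axes_to_pairs(n):
    """Hypothetical helper: expand an integer axes=n into explicit lists,
    the last n axes of the first array and the first n of the second."""
    return list(range(-n, 0)), list(range(0, n))

a = np.arange(24).reshape(2, 3, 4)

ax_a, ax_b = int_axes_to_pairs(2)
sizes_a = [a.shape[k] for k in ax_a]   # sizes of a's last two axes
sizes_b = [a.shape[k] for k in ax_b]   # sizes of a's first two axes
print(ax_a, ax_b)          # [-2, -1] [0, 1]
print(sizes_a, sizes_b)    # [3, 4] [2, 3] -> mismatch, hence the ValueError

# With n equal to a.ndim, "last n" and "first n" both mean all axes.
full = np.tensordot(a, a, axes=a.ndim)
print(int(full), int(np.sum(a * a)))   # 4324 4324
```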

These single-integer axes values are something that some developer thought would be convenient, but that does not mean they were rigorously thought out or tested.

The tuple form of axes gives you explicit control over which axes are paired:

In [74]: np.tensordot(a,a,axes=[(0,1,2),(0,1,2)])
Out[74]: array(4324)
In [75]: np.tensordot(a,a,axes=[(0,1),(0,1)])
Out[75]: 
array([[ 880,  940, 1000, 1060],
       [ 940, 1006, 1072, 1138],
       [1000, 1072, 1144, 1216],
       [1060, 1138, 1216, 1294]])
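One way to read the two tuple-axes results above is through equivalent np.einsum calls, and, for the second one, as a plain matrix product after merging the two summed axes. This is a sketch of equivalent formulations, not a description of how tensordot is required to compute them:

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)

# axes=[(0,1,2),(0,1,2)]: every axis summed, giving a scalar.
r3 = np.tensordot(a, a, axes=[(0, 1, 2), (0, 1, 2)])
e3 = np.einsum('ijk,ijk->', a, a)

# axes=[(0,1),(0,1)]: axes 0 and 1 summed, the two length-4 axes survive.
r2 = np.tensordot(a, a, axes=[(0, 1), (0, 1)])
e2 = np.einsum('ijk,ijl->kl', a, a)

# Merging the summed axes (2*3 = 6) turns it into an ordinary matrix product.
m = a.reshape(6, 4)
mm = m.T @ m

print(int(r3), int(e3))                                   # 4324 4324
print(r2.shape, np.array_equal(r2, e2), np.array_equal(r2, mm))  # (4, 4) True True
```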

Upvotes: 0
