Reputation: 22260
I'm trying to multiply a matrix by each frontal slice of a 3D tensor. If X.shape == (N, N) and tensor.shape == (N, N, Y), the resulting tensor should have shape (N, N, Y).
What's the proper np.tensordot syntax to achieve this?
I'm trying to limit myself to np.tensordot rather than np.einsum, because I want to translate this solution to Theano later. Unfortunately, Theano does not have an einsum equivalent implemented yet.
(Figure: matrix multiplication with the frontal slices of a 3D tensor. Graphics adapted from this paper about tensor multiplication.)
A non-tensordot solution is equivalent to the following:
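import numpy as np
tensor = np.random.rand(3, 3, 2)  # N = 3, two frontal slices
X = np.random.rand(3, 3)
output = np.zeros((3, 3, 2))
# Multiply X with each frontal slice tensor[:, :, k]
output[:, :, 0] = X.dot(tensor[:, :, 0])
output[:, :, 1] = X.dot(tensor[:, :, 1])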
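The two hard-coded assignments above generalize to any number of frontal slices. Here is a minimal reference sketch, assuming the slices are stacked along the last axis. The function name `slicewise_dot` is hypothetical and only used for illustration.

```python
import numpy as np

def slicewise_dot(X, tensor):
    """Reference loop (hypothetical helper): multiply X with every frontal slice tensor[:, :, k]."""
    n_rows = X.shape[0]
    n_cols, n_slices = tensor.shape[1], tensor.shape[2]
    output = np.zeros((n_rows, n_cols, n_slices))
    for k in range(n_slices):
        # Each frontal slice is an ordinary 2D matrix product.
        output[:, :, k] = X.dot(tensor[:, :, k])
    return output

rng = np.random.default_rng(0)
tensor = rng.random((3, 3, 2))
X = rng.random((3, 3))
output = slicewise_dot(X, tensor)
print(output.shape)  # (3, 3, 2)
```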
Upvotes: 1
Views: 1049
Reputation: 221554
The reduction is along axis=1 for X and axis=0 for tensor, so an np.tensordot-based solution would be:
np.tensordot(X, tensor, axes=([1], [0]))
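A quick way to check this is to compare the tensordot call against the slice-by-slice loop from the question. This is a minimal sketch with arbitrary sizes (N = 4 and 3 slices, chosen only for illustration).

```python
import numpy as np

rng = np.random.default_rng(42)
N, n_slices = 4, 3
X = rng.random((N, N))
tensor = rng.random((N, N, n_slices))

# Contract axis 1 of X with axis 0 of tensor.
result = np.tensordot(X, tensor, axes=([1], [0]))

# Reference: multiply X with each frontal slice separately, then stack along the last axis.
expected = np.stack([X.dot(tensor[:, :, k]) for k in range(n_slices)], axis=-1)

print(result.shape)                    # (4, 4, 3)
print(np.allclose(result, expected))   # True
```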
Explanation:
Take your iterative solution, and in it the first iteration:
output[:, :, 0] = X.dot(tensor[:, :, 0])
In the dot product, the first input is X, whose shape is (N, N), and the second input is tensor[:, :, 0], the first slice along the last axis, whose shape is also (N, N). That dot product reduces along the second axis of X, i.e. axis=1, and along the first axis, i.e. axis=0, of tensor[:, :, 0], which also happens to be the first axis of the entire array tensor. The same holds for every iteration. Therefore, in the big picture we need to do the same thing: reduce (lose) axis=1 of X and axis=0 of tensor, just as each iteration did. np.tensordot then keeps the remaining axes in order, X's axis 0 followed by tensor's axes 1 and 2, so the result has shape (N, N, Y) and its slice [:, :, k] equals X.dot(tensor[:, :, k]).
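The axis bookkeeping can be made explicit with np.einsum subscripts. Here einsum is used only as a NumPy cross-check (the question avoids it for the Theano port). In 'ij,jkl->ikl' the shared index j is summed away, i.e. axis 1 of X and axis 0 of tensor, and the surviving indices come out as X's first axis followed by tensor's remaining two axes. Sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.random((5, 5))
tensor = rng.random((5, 5, 2))

via_tensordot = np.tensordot(X, tensor, axes=([1], [0]))
# j appears in both operands but not in the output, so it is summed over.
via_einsum = np.einsum('ij,jkl->ikl', X, tensor)

print(np.allclose(via_tensordot, via_einsum))                    # True
# Slice k of the result is exactly the k-th frontal-slice product.
print(np.allclose(via_tensordot[:, :, 1], X @ tensor[:, :, 1]))  # True
```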
Integrating @hlin117's answer, the same contraction can also be written with a scalar axes argument:
np.tensordot(X, tensor, axes=1)
Timing:
>>> N = 200
>>> tensor = np.random.rand(N, N, 30)
>>> X = np.random.rand(N, N)
>>>
>>> %timeit np.tensordot(X, tensor, axes=([1], [0]))
100 loops, best of 3: 14.7 ms per loop
>>> %timeit np.tensordot(X, tensor, axes=1)
100 loops, best of 3: 15.2 ms per loop
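%timeit is an IPython magic, so the session above will not run in a plain Python interpreter. A rough equivalent using the standard-library timeit module is sketched below. Absolute numbers depend on the machine and the BLAS build, so no figures are reproduced here; the repeat and call counts are arbitrary choices.

```python
import timeit
import numpy as np

N = 200
tensor = np.random.rand(N, N, 30)
X = np.random.rand(N, N)

def explicit():
    return np.tensordot(X, tensor, axes=([1], [0]))

def scalar():
    return np.tensordot(X, tensor, axes=1)

# Both spellings describe the same contraction, so the results must agree.
assert np.allclose(explicit(), scalar())

for name, fn in [("axes=([1], [0])", explicit), ("axes=1", scalar)]:
    # Best of 3 repeats of 10 calls each, reported per call in milliseconds.
    best = min(timeit.repeat(fn, number=10, repeat=3)) / 10
    print(f"{name}: {best * 1e3:.2f} ms per call")
```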
Upvotes: 1
Reputation: 22260
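Looks like the above is equivalent to the following:
np.tensordot(X, tensor, axes=1)
This works because, when the axes argument is an integer n, np.tensordot sums over the last n axes of the first argument and the first n axes of the second argument. With axes=1, that pairs the last axis of X (axis 1) with the first axis of tensor (axis 0), which is exactly the pairing spelled out by axes=([1], [0]).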
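Non-square shapes make the pairing rule easier to see, because the contracted axes are then the only ones with matching length. In this sketch (sizes chosen arbitrarily), X has shape (4, 3) and tensor has shape (3, 5, 2); axes=1 sums X's last axis against tensor's first axis, leaving shape (4, 5, 2).

```python
import numpy as np

rng = np.random.default_rng(7)
X = rng.random((4, 3))          # last axis has length 3
tensor = rng.random((3, 5, 2))  # first axis has length 3

scalar_form = np.tensordot(X, tensor, axes=1)
explicit_form = np.tensordot(X, tensor, axes=([1], [0]))

print(scalar_form.shape)                        # (4, 5, 2)
print(np.allclose(scalar_form, explicit_form))  # True

# axes=2 would pair X's last two axes (4, 3) with tensor's first two (3, 5),
# which fails here because the lengths do not match.
try:
    np.tensordot(X, tensor, axes=2)
except ValueError as err:
    print("axes=2 fails:", err)
```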
Upvotes: 1