hlin117


Numpy tensor: Tensordot over frontal slices of tensor

I'm trying to multiply a matrix by each frontal slice of a 3D tensor. If X.shape == (N, N) and tensor.shape == (N, N, M), the resulting tensor should have shape (N, N, M).

What's the proper np.tensordot syntax to achieve this?

I'm trying to limit myself to np.tensordot rather than np.einsum, because I want to translate this solution to Theano later. Unfortunately, Theano does not implement an equivalent of np.einsum yet.


The original illustration was adapted from this paper about tensor multiplication. The non-tensordot answer is equivalent to the following:

import numpy as np

tensor = np.random.rand(3, 3, 2)
X = np.random.rand(3, 3)

# Multiply X into each frontal slice tensor[:, :, k]
output = np.zeros((3, 3, 2))
output[:, :, 0] = X.dot(tensor[:, :, 0])
output[:, :, 1] = X.dot(tensor[:, :, 1])
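The same loop generalizes to any number of frontal slices. This is a minimal reference sketch, assuming the slice count is tensor.shape[2]; the helper name frontal_slice_matmul is hypothetical and not part of NumPy.

```python
import numpy as np

def frontal_slice_matmul(X, tensor):
    """Loop reference (hypothetical helper): X.dot(tensor[:, :, k]) for every k."""
    out = np.empty((X.shape[0],) + tensor.shape[1:], dtype=np.result_type(X, tensor))
    for k in range(tensor.shape[2]):
        # Each frontal slice is an ordinary (N, N) matrix
        out[:, :, k] = X.dot(tensor[:, :, k])
    return out

rng = np.random.default_rng(0)
tensor = rng.random((3, 3, 2))
X = rng.random((3, 3))
output = frontal_slice_matmul(X, tensor)
print(output.shape)  # (3, 3, 2)
```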


Answers (2)

Divakar


The sum-reduction is along axis=1 of X and axis=0 of tensor, so the np.tensordot-based solution is:

np.tensordot(X, tensor, axes=([1], [0]))
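A quick way to convince yourself is to compare it against the loop from the question on random data. This is a small check under assumed sizes (N = 4, M = 5 slices), not part of the original answer:

```python
import numpy as np

rng = np.random.default_rng(42)
N, M = 4, 5  # M = number of frontal slices (assumed for the example)
X = rng.random((N, N))
tensor = rng.random((N, N, M))

# Contract axis 1 of X with axis 0 of tensor
result = np.tensordot(X, tensor, axes=([1], [0]))

# Loop-based reference from the question, stacked along the last axis
expected = np.stack([X.dot(tensor[:, :, k]) for k in range(M)], axis=-1)

print(result.shape)                   # (4, 4, 5)
print(np.allclose(result, expected))  # True
```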

Explanation:

Take the first iteration of your loop-based solution:

output[:, :, 0] = X.dot(tensor[:, :, 0])

In that dot product, the first input is X, with shape (N, N), and the second input is tensor[:, :, 0], the first slice along the last axis, also with shape (N, N). The dot product reduces (sums over) the second axis of X, i.e. axis=1, and the first axis, i.e. axis=0, of tensor[:, :, 0], which is also the first axis of the full array tensor. The same happens in every iteration. So in the big picture we need to do the same thing: reduce (lose) axis=1 of X and axis=0 of tensor, which is exactly what axes=([1], [0]) specifies. The output axes are the remaining axis of X followed by the remaining axes of tensor, giving shape (N, N, M), which matches the order the loop writes into.
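Another way to see why one contraction covers all slices: flatten the last two axes of tensor into an (N, N*M) matrix and do a single ordinary matrix product. This sketch only demonstrates the equivalence; np.tensordot does something roughly similar internally (transpose, reshape, dot), but treat that as background rather than a guarantee.

```python
import numpy as np

rng = np.random.default_rng(1)
N, M = 3, 2  # M = number of frontal slices (assumed)
X = rng.random((N, N))
tensor = rng.random((N, N, M))

# C-order reshape: flat[l, j*M + k] == tensor[l, j, k]
flat = tensor.reshape(N, N * M)
# One matrix product sums over l (axis 1 of X, axis 0 of tensor) for every (j, k)
product = X.dot(flat).reshape(N, N, M)

print(np.allclose(product, np.tensordot(X, tensor, axes=([1], [0]))))  # True
```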


Integrating @hlin117's answer, the same contraction can be written with a scalar axes argument:

np.tensordot(X, tensor, axes=1)

Timing:

>>> N = 200
>>> tensor = np.random.rand(N, N, 30)
>>> X = np.random.rand(N, N)
>>> 
>>> %timeit np.tensordot(X, tensor, axes=([1], [0]))
100 loops, best of 3: 14.7 ms per loop
>>> %timeit np.tensordot(X, tensor, axes=1)
100 loops, best of 3: 15.2 ms per loop
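To repeat the comparison outside IPython, the standard-library timeit module works as well. This is a sketch using the same sizes as above; the printed timings will vary by machine and NumPy build, so no expected numbers are given.

```python
import timeit
import numpy as np

N, M = 200, 30
rng = np.random.default_rng(0)
tensor = rng.random((N, N, M))
X = rng.random((N, N))

# Both spellings describe the same contraction, so the results must agree
explicit = np.tensordot(X, tensor, axes=([1], [0]))
scalar = np.tensordot(X, tensor, axes=1)
assert np.allclose(explicit, scalar)

cases = [
    ("axes=([1], [0])", lambda: np.tensordot(X, tensor, axes=([1], [0]))),
    ("axes=1", lambda: np.tensordot(X, tensor, axes=1)),
]
for label, stmt in cases:
    # Best of 3 repeats, each running the call 10 times
    best = min(timeit.repeat(stmt, number=10, repeat=3)) / 10
    print(f"{label}: {best * 1e3:.2f} ms per call")
```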


hlin117


Looks like the above is equivalent to the following:

np.tensordot(X, tensor, axes=1)

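This works because, when axes is an integer n, tensordot sums over the last n axes of the first argument and the first n axes of the second argument. Here the reduced axis of length N is the last axis of X and the first axis of tensor, so axes=1 is the same as axes=([1], [0]).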
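The integer rule can be spelled out explicitly as a pair of axis lists. The helper tensordot_int_axes below is hypothetical and only meant to illustrate what axes=n expands to:

```python
import numpy as np

def tensordot_int_axes(a, b, n):
    """Hypothetical helper: contract the last n axes of a with the first n axes of b."""
    a_axes = list(range(a.ndim - n, a.ndim))
    b_axes = list(range(n))
    return np.tensordot(a, b, axes=(a_axes, b_axes))

rng = np.random.default_rng(7)
X = rng.random((3, 3))
tensor = rng.random((3, 3, 2))

# n = 1 expands to axes=([1], [0]) for a 2D X
print(np.allclose(tensordot_int_axes(X, tensor, 1), np.tensordot(X, tensor, axes=1)))  # True
# axes=2 would contract both axes of X with the first two axes of tensor
print(np.tensordot(X, tensor, axes=2).shape)  # (2,)
```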

