Theo

Reputation: 61

Theano: How to take a "matrix outer product" where the elements are matrices

Basically, I have 2 tensors: A, where A.shape = (N, H, D), and B, where B.shape = (K, H, D). What I would like is a tensor, C, with shape (N, K, H, D) such that:

C[i, j, :, :] = A[i, :, :] * B[j, :, :]

Can this be done efficiently in Theano?

Side note: the actual end result that I would like to achieve is a tensor, E, of shape (N, K, D) such that:

E[i, j, :] = (A[i, :, :]*B[j, :, :]).sum(0)

So, if there is a way to get this directly, I would prefer it (hopefully saving space).
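For reference, the desired E can be written as a nested loop over i and j or, equivalently, as a single contraction over the shared H axis (a NumPy sketch; the sizes below are illustrative assumptions, not from the question):

```python
import numpy as np

# Illustrative sizes (assumptions for the example only)
N, K, H, D = 3, 4, 5, 6
rng = np.random.RandomState(0)
A = rng.randn(N, H, D)
B = rng.randn(K, H, D)

# Loop form of E[i, j, :] = (A[i] * B[j]).sum(0)
E_loop = np.empty((N, K, D))
for i in range(N):
    for j in range(K):
        E_loop[i, j] = (A[i] * B[j]).sum(0)

# Equivalent one-liner: contract over the shared H axis
E = np.einsum('ihd,jhd->ijd', A, B)
```

Both forms agree; the einsum spelling makes the (N, K, D) target shape and the summed-out H axis explicit.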

Upvotes: 5

Views: 845

Answers (2)

Matt Graham

Reputation: 191

You can get the final three-dimensional result E without creating the large intermediate array by using batched_dot:

import theano.tensor as tt
A = tt.tensor3('A')  # A.shape = (D, N, H)
B = tt.tensor3('B')  # B.shape = (D, H, K)
E = tt.batched_dot(A, B)  # E.shape = (D, N, K)

Unfortunately this requires permuting the dimensions of your input and output arrays. Although this can be done with dimshuffle in Theano, batched_dot seems unable to cope with arbitrarily strided arrays, so the following raises ValueError: Some matrix has no unit stride when E is evaluated:

import theano.tensor as tt
A = tt.tensor3('A')  # A.shape = (N, H, D)
B = tt.tensor3('B')  # B.shape = (K, H, D)
A_perm = A.dimshuffle((2, 0, 1))  # A_perm.shape = (D, N, H)
B_perm = B.dimshuffle((2, 1, 0))  # B_perm.shape = (D, H, K)
E_perm = tt.batched_dot(A_perm, B_perm)  # E_perm.shape = (D, N, K)
E = E_perm.dimshuffle((1, 2, 0))  # E.shape = (N, K, D)

batched_dot uses scan along the first (size D) dimension. As scan runs sequentially, this can be computationally less efficient than computing all the products in parallel when running on a GPU.
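For intuition, batched_dot applied to the permuted (D, N, H) and (D, H, K) inputs corresponds to NumPy's matmul over the leading batch axis: one (N, H) @ (H, K) product per d (a sketch with illustrative sizes):

```python
import numpy as np

# Illustrative sizes (assumptions for the example only)
N, K, H, D = 3, 4, 5, 6
rng = np.random.RandomState(1)
A = rng.randn(N, H, D)
B = rng.randn(K, H, D)

# Permute so D becomes the leading batch axis, as in the Theano snippet
A_perm = A.transpose(2, 0, 1)       # (D, N, H)
B_perm = B.transpose(2, 1, 0)       # (D, H, K)
E_perm = np.matmul(A_perm, B_perm)  # (D, N, K): one (N,H) @ (H,K) per d
E = E_perm.transpose(1, 2, 0)       # (N, K, D)
```

Each batched matrix product contracts the shared H axis, so E[i, j, d] = sum over h of A[i, h, d] * B[j, h, d], as required.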

You can trade off between the memory efficiency of the batched_dot approach and the parallelism of the broadcasting approach by using scan explicitly. The idea is to compute the full product C for batches of size M in parallel (assuming M is an exact factor of D), iterating over the batches with scan:

import theano as th
import theano.tensor as tt
# N, K, H, D and the batch size M (an exact factor of D) are assumed
# to be Python integers defined elsewhere.
A = tt.tensor3('A')  # A.shape = (N, H, D)
B = tt.tensor3('B')  # B.shape = (K, H, D)
# Split D into D // M batches of size M, keeping consecutive d indices
# together within a batch, and move the batch axis to the front.
A_batched = A.reshape((N, H, D // M, M)).dimshuffle(2, 3, 1, 0)  # (D//M, M, H, N)
B_batched = B.reshape((K, H, D // M, M)).dimshuffle(2, 3, 1, 0)  # (D//M, M, H, K)
E_batched, _ = th.scan(
    lambda a, b: (a[:, :, None, :] * b[:, :, :, None]).sum(1),
    sequences=[A_batched, B_batched]
)  # E_batched.shape = (D // M, M, K, N)
E = E_batched.reshape((D, K, N)).dimshuffle(2, 1, 0)  # E.shape = (N, K, D)

Upvotes: 0

Divakar

Reputation: 221634

One approach uses broadcasting -

(A[:,None]*B).sum(2)

Please note that the intermediate array created would have shape (N, K, H, D) before the sum-reduction along axis=2 reduces it to (N, K, D).
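The same broadcasting expression reads identically in NumPy, which makes it easy to sanity-check against the loop definition (a quick sketch with illustrative sizes):

```python
import numpy as np

# Illustrative sizes (assumptions for the example only)
N, K, H, D = 3, 4, 5, 6
rng = np.random.RandomState(3)
A = rng.randn(N, H, D)
B = rng.randn(K, H, D)

# A[:, None] has shape (N, 1, H, D); multiplying by B (K, H, D)
# broadcasts to the (N, K, H, D) intermediate, then summing over
# axis 2 (the H axis) yields the (N, K, D) result
E = (A[:, None] * B).sum(2)
```

This is the most parallel option, at the cost of materialising the full (N, K, H, D) intermediate.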

Upvotes: 2
