Reputation: 61
Basically, I have two tensors: A, where A.shape = (N, H, D), and B, where B.shape = (K, H, D). What I would like to do is to get a tensor, C, with shape (N, K, H, D) such that:
C[i, j, :, :] = A[i, :, :] * B[j, :, :].
Can this be done efficiently in Theano?
Side note: The actual end result that I would like to achieve is a tensor, E, of shape (N, K, D) such that:
E[i, j, :] = (A[i, :, :] * B[j, :, :]).sum(0)
So, if there is a way to get this directly, I would prefer it (it hopefully saves on space).
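For concreteness, here is a plain NumPy sketch of the result I am after (the sizes are made up, and einsum is only a readable reference here, not the Theano solution I am looking for):

import numpy as np
N, K, H, D = 3, 4, 5, 6
A_val = np.random.randn(N, H, D)
B_val = np.random.randn(K, H, D)
# E_ref[i, j, :] = (A_val[i] * B_val[j]).sum(0): elementwise product, summed over H
E_ref = np.einsum('nhd,khd->nkd', A_val, B_val)  # E_ref.shape = (N, K, D)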
Upvotes: 5
Views: 845
Reputation: 191
You can get the final three-dimensional result E without creating the large intermediate array by using batched_dot:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (D, N, H)
B = tt.tensor3('B') # B.shape = (D, H, K)
E = tt.batched_dot(A, B) # E.shape = (D, N, K)
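As a quick sanity check (a sketch with arbitrary small sizes, assuming the inputs have already been permuted so that D comes first), this can be compiled and compared against a NumPy reference:

import numpy as np
import theano as th
f = th.function([A, B], E)
A_val = np.random.randn(6, 3, 5).astype(th.config.floatX)  # (D, N, H)
B_val = np.random.randn(6, 5, 4).astype(th.config.floatX)  # (D, H, K)
# batched_dot performs one matrix product per slice along D
assert np.allclose(f(A_val, B_val),
                   np.einsum('dnh,dhk->dnk', A_val, B_val), atol=1e-5)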
Unfortunately this requires you to permute the dimensions of your input and output arrays. Though this can be done with dimshuffle in Theano, it seems batched_dot can't cope with arbitrarily strided arrays, and so the following raises a ValueError: Some matrix has no unit stride when E is evaluated:
import theano.tensor as tt
A = tt.tensor3('A') # A.shape = (N, H, D)
B = tt.tensor3('B') # B.shape = (K, H, D)
A_perm = A.dimshuffle((2, 0, 1)) # A_perm.shape = (D, N, H)
B_perm = B.dimshuffle((2, 1, 0)) # B_perm.shape = (D, H, K)
E_perm = tt.batched_dot(A_perm, B_perm) # E_perm.shape = (D, N, K)
E = E_perm.dimshuffle((1, 2, 0)) # E.shape = (N, K, D)
batched_dot uses scan along the first (size D) dimension. As scan is performed sequentially, this could be computationally less efficient than computing all the products in parallel when running on a GPU.
You can trade off between the memory efficiency of the batched_dot approach and the parallelism of the broadcasting approach by using scan explicitly. The idea would be to calculate the full product C for batches of size M in parallel (assuming M is an exact factor of D), iterating over the batches with scan:
import theano as th
import theano.tensor as tt
# N, K, H, D and the batch size M are assumed to be known Python ints,
# with M an exact factor of D
A = tt.tensor3('A')  # A.shape = (N, H, D)
B = tt.tensor3('B')  # B.shape = (K, H, D)
A_batched = A.reshape((N, H, M, D // M))  # split the D axis into (M, D // M)
B_batched = B.reshape((K, H, M, D // M))
E_batched, _ = th.scan(
    lambda a, b: (a[:, :, None, :] * b[:, :, :, None]).sum(1),
    sequences=[A_batched.T, B_batched.T]  # iterate over the D // M batches
)
# swap the split axes back before flattening so the D slices end up in
# their original order
E = E_batched.dimshuffle((1, 0, 2, 3)).reshape((D, K, N)).T  # E.shape = (N, K, D)
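To convince yourself that the batched version reproduces the result from the question, a quick check (a sketch, assuming you set concrete sizes such as N, K, H, D, M = 3, 4, 5, 6, 2 before building the graph above) might look like:

import numpy as np
f = th.function([A, B], E)
A_val = np.random.randn(N, H, D).astype(th.config.floatX)
B_val = np.random.randn(K, H, D).astype(th.config.floatX)
# reference: elementwise product of the (H, D) slices, summed over H
E_ref = np.einsum('nhd,khd->nkd', A_val, B_val)
assert np.allclose(f(A_val, B_val), E_ref, atol=1e-5)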
Upvotes: 0
Reputation: 221634
One approach that uses broadcasting -
(A[:,None]*B).sum(2)
Please note that the intermediate array being created would be of shape (N, K, H, D) before sum-reduction along axis=2 reduces it to (N, K, D).
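In full, a minimal sketch of how this would look symbolically in Theano (the variable names are just for illustration):

import theano as th
import theano.tensor as tt
A = tt.tensor3('A')          # A.shape = (N, H, D)
B = tt.tensor3('B')          # B.shape = (K, H, D)
E = (A[:, None] * B).sum(2)  # broadcast to (N, K, H, D), then sum over H
f = th.function([A, B], E)   # f(A_val, B_val) has shape (N, K, D)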
Upvotes: 2