Reputation: 263
Suppose I have the matrix P = [[0,1],[1,0]] and a vector v = [a,b]. Multiplying them gives:
Pv = [b,a]
The matrix P is simply a permutation matrix: it reorders (here, swaps) the elements of v.
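As a quick check of the 2D case in PyTorch, here is a minimal sketch where the numbers 10 and 20 stand in for a and b (an illustrative choice, not from the question):

```python
import torch

# P swaps the two entries of a length-2 vector.
P = torch.tensor([[0., 1.], [1., 0.]])
v = torch.tensor([10., 20.])  # concrete stand-ins for [a, b]

# A 2D @ 1D product is an ordinary matrix-vector multiplication.
Pv = P @ v
print(Pv)  # tensor([20., 10.])
```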
Now suppose I keep the same P, but have the matrices M1 = [[1,2],[3,4]] and M2 = [[5,6],[7,8]], which I stack into the 3D tensor T = [[[1,2],[3,4]], [[5,6],[7,8]]] with shape (2,2,2), i.e. (C,W,H). I want to multiply P by T such that:
PT = [[[5,6],[7,8]], [[1,2],[3,4]]]
Note that M1 now equals [[5,6],[7,8]] and M2 equals [[1,2],[3,4]], since the values have been permuted along the C dimension of T (C,W,H).
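To pin down the intended operation, a minimal sketch that spells it out with an explicit loop: each output channel i is a weighted sum of the input channels, with weights taken from row i of P, i.e. PT[i] = sum over j of P[i, j] * T[j]. This is only a reference definition for checking results, not an efficient implementation:

```python
import torch

P = torch.tensor([[0., 1.], [1., 0.]])
M1 = torch.tensor([[1., 2.], [3., 4.]])
M2 = torch.tensor([[5., 6.], [7., 8.]])
T = torch.stack([M1, M2])  # shape (C, W, H) = (2, 2, 2)

# Output channel i mixes the input channels using row i of P:
#   PT[i] = sum_j P[i, j] * T[j]
PT = torch.stack([
    sum(P[i, j] * T[j] for j in range(T.shape[0]))
    for i in range(P.shape[0])
])

# Channels are swapped: PT[0] equals M2 and PT[1] equals M1.
print(PT)
```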
How can I compute PT (P a 2D tensor, T a 3D tensor) in PyTorch using matmul? The following does not produce this result:
torch.matmul(P, T)
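The call above runs, but torch.matmul broadcasts when one argument has more than two dimensions: it treats the leading dimension of T as a batch and multiplies P with each 2x2 matrix separately. A minimal sketch showing that this swaps rows within each matrix rather than swapping the matrices along C:

```python
import torch

P = torch.tensor([[0., 1.], [1., 0.]])
T = torch.tensor([[[1., 2.], [3., 4.]],
                  [[5., 6.], [7., 8.]]])

# matmul sees T as a batch of two 2x2 matrices and computes P @ T[c]
# for each c, so rows are swapped inside each matrix instead of the
# matrices being swapped along the C dimension.
out = torch.matmul(P, T)

# out[0] is [[3, 4], [1, 2]] and out[1] is [[7, 8], [5, 6]]
print(out)
```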
Upvotes: 1
Views: 1590
Reputation: 40618
An alternative to @mlucy's answer is to use torch.einsum. This lets you define the operation yourself, without worrying about torch.matmul's broadcasting requirements:
>>> torch.einsum('ij,jkl->ikl', P, T)
tensor([[[5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4]]])
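A self-contained version of the einsum call, with P and T defined as in the question (float values are an illustrative choice). In the subscript string, the index j is shared by P's columns and T's channel axis, so it is summed over, while k and l (W and H) pass through unchanged:

```python
import torch

P = torch.tensor([[0., 1.], [1., 0.]])
T = torch.tensor([[[1., 2.], [3., 4.]],
                  [[5., 6.], [7., 8.]]])

# 'ij,jkl->ikl': sum over j (P's columns and T's C axis);
# k and l (W and H) are carried through to the output.
PT = torch.einsum('ij,jkl->ikl', P, T)
print(PT)
```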
Or with torch.matmul, by flattening the W and H dimensions of T so that P multiplies an ordinary 2D matrix:
>>> (P @ T.flatten(1)).reshape_as(T)
tensor([[[5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4]]])
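A sketch of the same matmul approach with the intermediate shape made explicit (P and T as in the question, float values chosen for illustration). After flatten(1), each row of the 2D matrix holds one whole channel, so a plain 2D product mixes channels; reshape_as then restores (C, W, H):

```python
import torch

P = torch.tensor([[0., 1.], [1., 0.]])
T = torch.tensor([[[1., 2.], [3., 4.]],
                  [[5., 6.], [7., 8.]]])

flat = T.flatten(1)  # (C, W, H) -> (C, W*H), here (2, 4)
print(flat.shape)    # torch.Size([2, 4])

# Rows of flat are channels of T, so P @ flat reorders the channels.
PT = (P @ flat).reshape_as(T)  # back to (C, W, H)
print(PT)
```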
Upvotes: 1
Reputation: 5289
You could flatten T to shape (C, W*H) so that P multiplies along C, then restore the original shape:
torch.matmul(P, T.flatten(1)).reshape(T.shape)
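A sketch checking that the flatten-and-reshape expression from this answer works for any number of channels, assuming P is a C x C permutation matrix. The sizes, the seed, and the perm variable are illustrative choices. For a permutation matrix the product simply reorders the channels, so it should agree with indexing T[perm]:

```python
import torch

torch.manual_seed(0)
C, W, H = 5, 3, 4
T = torch.randn(C, W, H)

# Row i of P is the one-hot vector selecting channel perm[i].
perm = torch.randperm(C)
P = torch.eye(C)[perm]

# Flatten W and H, multiply along C, then restore (C, W, H).
PT = torch.matmul(P, T.flatten(1)).reshape(T.shape)

# For a permutation matrix this is the same as reordering the channels.
assert torch.allclose(PT, T[perm])
```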
Upvotes: 0