phil

Reputation: 263

Multiply 2D tensor with 3D tensor in PyTorch

Suppose I have a matrix P = [[0,1],[1,0]] and a vector v = [a,b]. Multiplying them gives:

Pv = [b,a]
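For reference, the vector case already behaves as expected in PyTorch. This minimal sketch uses the numbers 10 and 20 as stand-ins for a and b (an assumption for illustration only):

```python
import torch

# P swaps the two entries of a length-2 vector.
P = torch.tensor([[0., 1.], [1., 0.]])
v = torch.tensor([10., 20.])  # stand-in values for [a, b]

Pv = P @ v  # 2D @ 1D is an ordinary matrix-vector product
print(Pv.tolist())  # [20.0, 10.0]
```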

P is simply a permutation matrix: it swaps the order of the elements.

Now suppose I keep the same P, but have the matrices M1 = [[1,2],[3,4]] and M2 = [[5,6],[7,8]], stacked into the 3D tensor T = [[[1,2],[3,4]], [[5,6],[7,8]]] of shape (2,2,2), laid out as (C,W,H). I want to multiply P by T so that:

PT = [[[5,6],[7,8]], [[1,2],[3,4]]]

That is, the first channel of PT is now [[5,6],[7,8]] (formerly M2) and the second is [[1,2],[3,4]] (formerly M1): the matrices have been permuted along the C dimension of T (C,W,H).
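Written element-wise, the product being asked for contracts P with the first (C) index of T. A sketch of the formula with zero-based indices, matching PyTorch's indexing:

```latex
(PT)_{ikl} = \sum_{j} P_{ij}\, T_{jkl},
\qquad
P = \begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}
\;\Longrightarrow\;
(PT)_{0kl} = T_{1kl}, \quad (PT)_{1kl} = T_{0kl}
```

So channel 0 of the result is M2 and channel 1 is M1, which is exactly the PT shown above.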

How can I compute PT (P a 2D tensor, T a 3D tensor) in PyTorch using matmul? The following does not produce the result above:

torch.matmul(P, T)
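The reason is torch.matmul's batching rule: when one argument has more than two dimensions, it performs a batched matrix multiply, broadcasting the 2D P across T's leading axis. It therefore computes P @ T[c] for each channel c, permuting the rows inside each matrix rather than the matrices themselves. A small sketch to show this:

```python
import torch

P = torch.tensor([[0, 1], [1, 0]])
T = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]]])

# 2D @ 3D: T is treated as a batch of two 2x2 matrices, and P is
# applied to each one separately, swapping the rows within each matrix.
out = torch.matmul(P, T)
print(out.tolist())  # [[[3, 4], [1, 2]], [[7, 8], [5, 6]]]

# Equivalent explicit loop over the batch (C) dimension.
manual = torch.stack([P @ T[c] for c in range(T.shape[0])])
assert torch.equal(out, manual)
```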

Upvotes: 1

Views: 1590

Answers (2)

Ivan

Reputation: 40618

An alternative to @mlucy's answer is to use torch.einsum. This has the benefit of letting you spell out the operation yourself, without worrying about torch.matmul's shape requirements:

>>> torch.einsum('ij,jkl->ikl', P, T)
tensor([[[5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4]]])
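In the subscript string 'ij,jkl->ikl', the index j appears in both inputs but not in the output, so einsum sums over it; j runs over P's columns and T's first (C) axis. The sketch below spells that sum out by hand as a check (the loop is for illustration only, not a recommended implementation):

```python
import torch

P = torch.tensor([[0, 1], [1, 0]])
T = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]]])

# out[i, k, l] = sum_j P[i, j] * T[j, k, l]
out = torch.einsum('ij,jkl->ikl', P, T)

# Same sum written explicitly: P[:, j, None, None] has shape (C, 1, 1)
# and broadcasts against T[j] of shape (W, H).
ref = sum(P[:, j, None, None] * T[j] for j in range(T.shape[0]))

print(out.tolist())  # [[[5, 6], [7, 8]], [[1, 2], [3, 4]]]
assert torch.equal(out, ref)
```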

Or with torch.matmul, by flattening the trailing dimensions first:

>>> (P @ T.flatten(1)).reshape_as(T)
tensor([[[5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4]]])
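The matmul version works because flatten(1) merges every axis after the first, turning (C, W, H) into (C, W*H). Each row is then one whole channel, so an ordinary 2D product with P mixes (here, permutes) entire channels. A step-by-step sketch of the shapes:

```python
import torch

P = torch.tensor([[0, 1], [1, 0]])
T = torch.tensor([[[1, 2], [3, 4]],
                  [[5, 6], [7, 8]]])

flat = T.flatten(1)          # (C, W, H) -> (C, W*H)
print(flat.shape)            # torch.Size([2, 4])
print(flat.tolist())         # [[1, 2, 3, 4], [5, 6, 7, 8]]

mixed = P @ flat             # (C, C) @ (C, W*H) -> (C, W*H)
out = mixed.reshape_as(T)    # back to (C, W, H)
print(out.tolist())          # [[[5, 6], [7, 8]], [[1, 2], [3, 4]]]
```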

Upvotes: 1

mlucy

Reputation: 5289

You could do something like:

torch.matmul(P, T.flatten(1)).reshape(T.shape)
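The same flatten-then-matmul idea generalizes to any number of channels and trailing axes. Below is a minimal sketch wrapped in a hypothetical helper, apply_channel_matrix (not a PyTorch API). It assumes the input has at least two dimensions and casts P to the input's dtype, since torch.matmul expects both operands to share a dtype. Reshaping to (P.shape[0], *rest) rather than to the input's shape also allows a non-square P.

```python
import torch


def apply_channel_matrix(P: torch.Tensor, X: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: left-multiply the first (channel) axis of X by P.

    X has shape (C, ...) with at least one trailing axis; P has shape
    (C_out, C). The result has shape (C_out, ...).
    """
    P = P.to(X.dtype)                 # match dtypes before matmul
    flat = X.flatten(1)               # (C, prod(trailing dims))
    out = torch.matmul(P, flat)       # (C_out, prod(trailing dims))
    return out.reshape(P.shape[0], *X.shape[1:])


X = torch.arange(2 * 3 * 4, dtype=torch.float32).reshape(2, 3, 4)
P = torch.tensor([[0, 1], [1, 0]])
Y = apply_channel_matrix(P, X)

# Row i of a permutation matrix has its single 1 in column perm[i],
# so the product simply reorders the channels of X.
perm = P.argmax(dim=1)
assert torch.equal(Y, X[perm])
```

For a pure permutation, indexing with perm gives the same result without a matmul; the matrix form is useful when P is a general channel-mixing matrix.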

Upvotes: 0
