I have a tensor containing five 2x2 matrices, with shape (1, 5, 2, 2), and a tensor containing 5 elements, with shape [5]. I want to multiply each 2x2 matrix (in the former tensor) by the corresponding value (in the latter tensor). The resulting tensor should have shape (1, 5, 2, 2). How can I do that?
I get the following error when I run this code:

```python
import torch

a = torch.rand(1,5,2,2)
print(a.shape)
b = torch.rand(5)
print(b.shape)
mul = a*b
```

```
RuntimeError: The size of tensor a (2) must match the size of tensor b (5) at non-singleton dimension 3
```
You can use either `a * b` or `torch.mul(a, b)`, but the shapes must be made compatible first. Here this is done with `permute()` before and after the multiplication:
```python
import torch

a = torch.ones(1, 5, 2, 2)
b = torch.rand(5)
print(a.shape)  # torch.Size([1, 5, 2, 2])
print(b.shape)  # torch.Size([5])

c = (a.permute(0, 2, 3, 1) * b).permute(0, 3, 1, 2)
print(c.shape)  # torch.Size([1, 5, 2, 2])

# OR #
c = torch.mul(a.permute(0, 2, 3, 1), b).permute(0, 3, 1, 2)
print(c.shape)  # torch.Size([1, 5, 2, 2])
```
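As a quick sanity check (a sketch that recreates random `a` and `b` with a fixed seed, which is only there for reproducibility), the permute result can be compared with an explicit loop that scales each 2x2 matrix by its own scalar:

```python
import torch

torch.manual_seed(0)  # fixed seed only so the check is reproducible
a = torch.rand(1, 5, 2, 2)
b = torch.rand(5)

# Permute-based version from the answer
c = (a.permute(0, 2, 3, 1) * b).permute(0, 3, 1, 2)

# Reference: scale the i-th 2x2 matrix by b[i], one matrix at a time,
# then stack back to (5, 2, 2) and add the leading batch dim -> (1, 5, 2, 2)
expected = torch.stack([a[0, i] * b[i] for i in range(5)]).unsqueeze(0)

print(c.shape)                      # torch.Size([1, 5, 2, 2])
print(torch.allclose(c, expected))  # True
```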
The `permute()` function reorders the dimensions according to its arguments. That is, `a.permute(0, 2, 3, 1)` has shape `torch.Size([1, 2, 2, 5])`, which is compatible with the shape of `b` (`torch.Size([5])`) for element-wise multiplication: broadcasting aligns shapes starting from the last dimension, and the last dimension of the permuted `a` equals the only dimension of `b`. After the multiplication we reorder the dimensions again with `permute(0, 3, 1, 2)` to get back the desired shape `torch.Size([1, 5, 2, 2])`.
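Permuting is not the only way to get compatible shapes. Since broadcasting compares shapes from the last dimension backwards, another option (an alternative, not part of the original answer) is to reshape `b` to (1, 5, 1, 1) so its size-5 dimension lines up with dimension 1 of `a`. A minimal sketch comparing the two:

```python
import torch

a = torch.rand(1, 5, 2, 2)
b = torch.rand(5)

# Broadcasting compares shapes right to left:
#   a:               (1, 5, 2, 2)
#   b.view(1,5,1,1): (1, 5, 1, 1)  -> the size-1 dims are expanded to 2
c_view = a * b.view(1, 5, 1, 1)

# Permute-based version from the answer
c_perm = (a.permute(0, 2, 3, 1) * b).permute(0, 3, 1, 2)

print(c_view.shape)                    # torch.Size([1, 5, 2, 2])
print(torch.allclose(c_view, c_perm))  # True
```

The reshaped version avoids the two `permute()` calls, while the permute version keeps `b` untouched; both produce the same values.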
You can read about `permute()` in the docs. In short, the dimensions of the current shape `[1, 5, 2, 2]` are numbered 0 to 3, and the arguments list the old dimensions in their new order. For `a.permute(0, 2, 3, 1)`, the first dimension stays in place because the first argument is 0; the second dimension (index 1) moves to the fourth position because 1 is the fourth argument; and the third and fourth dimensions (indices 2 and 3) move to the second and third positions because 2 and 3 are the second and third arguments. Remember that dimensions are zero-indexed: the 4th dimension, for instance, is written as 3 (not 4).
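To make the index bookkeeping concrete, the following sketch fills `a` with distinct values (via `torch.arange`, chosen only so the moves are visible) and checks that `a.permute(0, 2, 3, 1)` places old dimension `dims[i]` at new position `i`, and that `permute(0, 3, 1, 2)` undoes it:

```python
import torch

a = torch.arange(20).reshape(1, 5, 2, 2)  # distinct values make moves visible

p = a.permute(0, 2, 3, 1)
print(p.shape)  # torch.Size([1, 2, 2, 5])

# New dimension i is old dimension dims[i] with dims = (0, 2, 3, 1),
# so p[i, j, k, l] is the same element as a[i, l, j, k].
assert p[0, 1, 0, 3] == a[0, 3, 1, 0]

# permute(0, 3, 1, 2) is the inverse permutation and restores the layout
back = p.permute(0, 3, 1, 2)
print(back.shape)            # torch.Size([1, 5, 2, 2])
print(torch.equal(back, a))  # True
```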
EDIT
If you want to element-wise multiply tensors of shape `[32, 5, 2, 2]` and `[32, 5]`, for example, so that each 2x2 matrix is multiplied by the corresponding value, you can rearrange the dimensions of the first tensor to `[2, 2, 32, 5]` with `permute(2, 3, 0, 1)`, perform the multiplication with `a * b`, and then return to the original shape with `permute(2, 3, 0, 1)` again (this permutation swaps the first two dimensions with the last two, so it is its own inverse). The key here is that the last n dimensions of the first tensor need to match the n dimensions of the second tensor, because broadcasting aligns shapes from the right. In our case n = 2.
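A sketch of the batched case from the EDIT, assuming `a` has shape `[32, 5, 2, 2]` and `b` has shape `[32, 5]` (random data, just for illustration). It also compares against the broadcasting alternative `b[:, :, None, None]`, which is not part of the original answer:

```python
import torch

a = torch.rand(32, 5, 2, 2)
b = torch.rand(32, 5)

# [32, 5, 2, 2] -> [2, 2, 32, 5]: the trailing dims now match b's shape [32, 5]
c = (a.permute(2, 3, 0, 1) * b).permute(2, 3, 0, 1)
# permute(2, 3, 0, 1) swaps the two leading and two trailing dims,
# so applying it a second time restores the original order

# Alternative: append two size-1 dims to b so it broadcasts as [32, 5, 1, 1]
c_alt = a * b[:, :, None, None]

print(c.shape)                   # torch.Size([32, 5, 2, 2])
print(torch.allclose(c, c_alt))  # True
```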
Hope that helps.