Joon. P

Reputation: 2298

PyTorch broadcast multiplication of 4D and 2D matrix?

How do I broadcast to multiply these two matrices together?

x: torch.Size([10, 120, 180, 30]) # (N, H, W, C)
W: torch.Size([64, 30]) # (Y, C)

The output should be:

(10, 120, 180, 64) == (N, H, W, Y)

Upvotes: 0

Views: 919

Answers (1)

Szymon Maszke

Reputation: 24701

I assume x is a batch of examples and W is the corresponding weight matrix. In this case you can simply do:

out = x @ W.T

which is a (broadcast) matrix multiplication over the last dimension, not an element-wise one. You can't get this output shape with element-wise multiplication alone, and such an operation would not make sense here. All you could do is unsqueeze both tensors so that their shapes become broadcastable, and then reduce over the dimension you don't need, like this:

x: torch.Size([10, 120, 180, 30, 1])
W: torch.Size([1, 1, 1, 30, 64]) # transposition would be needed as well

After such unsqueezing you could compute x * W and then sum (or average) over the C dimension, which is dim=3 (equivalently dim=-2) of the 5D product, to get the desired (10, 120, 180, 64) shape.
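A minimal sketch of the unsqueeze-and-reduce approach. It assumes random float32 inputs and uses much smaller sizes than the question (a hypothetical choice, only to keep memory low, since the intermediate product has shape (N, H, W, C, Y)). x gets a trailing singleton dimension, W is transposed and reshaped to (1, 1, 1, C, Y), and the broadcast product is summed over dim=-2 (the C dimension).

```python
import torch

# Hypothetical small sizes; the question uses N=10, H=120, W=180, C=30, Y=64
N, H, Wd, C, Y = 2, 4, 5, 30, 64
x = torch.randn(N, H, Wd, C)   # (N, H, W, C)
W = torch.randn(Y, C)          # (Y, C)

x5 = x.unsqueeze(-1)             # (N, H, Wd, C, 1)
W5 = W.T.reshape(1, 1, 1, C, Y)  # (1, 1, 1, C, Y): transposed so C comes before Y

prod = x5 * W5                   # broadcasts to (N, H, Wd, C, Y)
out_sum = prod.sum(dim=-2)       # reduce over C -> (N, H, Wd, Y)

print(prod.shape)     # torch.Size([2, 4, 5, 30, 64])
print(out_sum.shape)  # torch.Size([2, 4, 5, 64])
```

With the question's real sizes the same code works, but prod would hold 10 * 120 * 180 * 30 * 64 elements, which is why the plain x @ W.T is usually preferred.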

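For clarity, the two ways are not always equivalent: summing over C gives the same values as x @ W.T (up to floating-point rounding), whereas the mean gives that result divided by C (30 here). The broadcast version also materializes a (10, 120, 180, 30, 64) intermediate, about 415 million elements, so it needs far more memory than the matrix multiplication.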
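A small check of how the two reductions relate to the matrix multiplication, again assuming random inputs and hypothetical small sizes: summing the broadcast product over C should match x @ W.T up to floating-point rounding, and the mean should match it once multiplied back by C.

```python
import torch

torch.manual_seed(0)  # reproducible random inputs
N, H, Wd, C, Y = 2, 3, 4, 30, 8  # hypothetical small sizes
x = torch.randn(N, H, Wd, C)
W = torch.randn(Y, C)

reference = x @ W.T                                  # (N, H, Wd, Y) via matmul
prod = x.unsqueeze(-1) * W.T.reshape(1, 1, 1, C, Y)  # (N, H, Wd, C, Y) via broadcasting

out_sum = prod.sum(dim=-2)    # sum over C
out_mean = prod.mean(dim=-2)  # mean over C, i.e. the sum divided by C

sum_matches = torch.allclose(out_sum, reference, atol=1e-4)
mean_matches_scaled = torch.allclose(out_mean * C, reference, atol=1e-4)
print(sum_matches, mean_matches_scaled)
```

The tolerance is needed because the matmul kernel and the explicit sum may accumulate the 30 products in a different order.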

Upvotes: 1
