Luca Clissa

Reputation: 908

How to add two torch tensors along given dimension?

I have a torch tensor, pred, of shape (B, 2, H, W), and I want to add two different values, val1 and val2, to the two channels along axis 1 (val1 to channel 0 and val2 to channel 1).

I managed to do it in a "mechanical" way by indexing each channel directly:

import torch

def thresh_format(pred, val1, val2):
    # Offset tensor with the same shape as pred; each channel is filled separately.
    tr = torch.zeros_like(pred)
    tr[:, 0, :, :] = tr[:, 0, :, :].add(val1)
    tr[:, 1, :, :] = tr[:, 1, :, :].add(val2)
    return pred + tr

However, I'm wondering if there's a "better" way to do it, e.g. by exploiting broadcasting. My understanding from the documentation is that broadcasting aligns shapes starting from the trailing dimensions, so I'm puzzled about how to make it work for dimension 1. Any ideas?
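The trailing-dimension rule can be checked directly on shapes. A minimal sketch, assuming PyTorch 1.8+ (for `torch.broadcast_shapes`) and arbitrary illustrative sizes B=4, H=3, W=5:

```python
import torch

# Illustrative (B, 2, H, W) shape; the sizes are arbitrary assumptions.
pred_shape = (4, 2, 3, 5)

# Shapes are compared from the right, so a flat (2,) tensor is lined up
# with W (= 5), not with the channel axis, and broadcasting fails.
try:
    torch.broadcast_shapes(pred_shape, (2,))
    flat_ok = True
except RuntimeError:
    flat_ok = False

# Appending two singleton dimensions moves the "2" onto axis 1 of pred.
channel_shape = torch.broadcast_shapes(pred_shape, (2, 1, 1))

print(flat_ok)        # False
print(channel_shape)  # torch.Size([4, 2, 3, 5])
```

So the channel values have to carry explicit size-1 dimensions for H and W before broadcasting can reach axis 1.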

Upvotes: 0

Views: 4125

Answers (1)

lwohlhart

Reputation: 1932

The easiest way to achieve this is to put val1 and val2 in a tensor and reshape it to (1, 2, 1, 1), so that it lines up with pred along the channel dimension (axis 1).

 pred + torch.tensor([val1, val2]).reshape((1,-1,1,1))

This way, during the addition, torch automatically broadcasts the size-1 dimensions of the value tensor (axes 0, 2 and 3) to match pred's B, H and W.
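A minimal sketch that wraps the one-liner as a replacement for the question's `thresh_format`. The name `thresh_format_broadcast` is hypothetical, and creating the offsets with `pred`'s `dtype` and `device` is an added assumption: a multi-element tensor created on the CPU cannot be added to a `pred` that lives on a CUDA device.

```python
import torch


def thresh_format_loop(pred, val1, val2):
    # The question's channel-by-channel version, kept for comparison.
    tr = torch.zeros_like(pred)
    tr[:, 0, :, :] = tr[:, 0, :, :].add(val1)
    tr[:, 1, :, :] = tr[:, 1, :, :].add(val2)
    return pred + tr


def thresh_format_broadcast(pred, val1, val2):
    # Hypothetical name. Build the per-channel offsets on pred's device/dtype,
    # then reshape to (1, 2, 1, 1) so they line up with axis 1.
    offsets = torch.tensor([val1, val2], dtype=pred.dtype, device=pred.device)
    return pred + offsets.reshape(1, -1, 1, 1)


pred = torch.randn(3, 2, 4, 5)  # illustrative (B, 2, H, W) input
out = thresh_format_broadcast(pred, 0.5, -1.0)
print(out.shape)  # torch.Size([3, 2, 4, 5])
print(torch.allclose(out, thresh_format_loop(pred, 0.5, -1.0)))
```

The loop version is included only to check that both approaches give the same result; the broadcast version avoids allocating a full-size offset tensor.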

The broadcasting here is similar to what happens when you add a plain scalar to a tensor, like:

>>> torch.ones((2, 2)) + 3.
tensor([[4., 4.],
        [4., 4.]])

But instead of broadcasting a single scalar to every element of the tensor, in the case above each value is broadcast along the dimensions where the value tensor has size 1, i.e. every dimension except the channel axis.

>>> import torch
>>> B=1; H=2; W=2; val1=3; val2=7
>>> pred = torch.zeros((B,2,H,W))
>>> val = torch.tensor([val1, val2]).reshape((1,-1,1,1))
>>> pred
tensor([[[[0., 0.],
          [0., 0.]],

         [[0., 0.],
          [0., 0.]]]])
>>> val
tensor([[[[3]],

         [[7]]]])
>>> pred + val
tensor([[[[3., 3.],
          [3., 3.]],

         [[7., 7.],
          [7., 7.]]]])
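Since the title asks about a given dimension in general, the same reshape trick can be parameterized by `dim`. A minimal sketch; `add_along_dim` is a hypothetical helper, and it assumes `values` holds exactly one entry per index along `dim`:

```python
import torch


def add_along_dim(x, values, dim):
    # Hypothetical helper: add values[i] to every element of slice i along `dim`.
    values = torch.as_tensor(values, dtype=x.dtype, device=x.device)
    if values.numel() != x.shape[dim]:
        raise ValueError("need exactly one value per index along dim")
    # All ones except at `dim`, e.g. [1, -1, 1, 1] for dim=1 on a 4-D tensor.
    shape = [1] * x.dim()
    shape[dim] = -1
    return x + values.reshape(shape)


x = torch.zeros(1, 2, 2, 2)
print(add_along_dim(x, [3, 7], dim=1)[0, :, 0, 0])   # tensor([3., 7.])
print(add_along_dim(x, [10, 20], dim=-1)[0, 0, 0])   # tensor([10., 20.])
```

With `dim=1` this reduces to the `reshape((1,-1,1,1))` used above; negative `dim` values work because Python list indexing accepts them.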

Upvotes: 2
