french_fries

Reputation: 1

What does [1,2] mean in .mean([1,2]) for a tensor?

I have a tensor with shape torch.Size([3, 224, 225]). When I do tensor.mean([1,2]) I get tensor([0.6893, 0.5840, 0.4741]). What does [1,2] mean here?

Upvotes: 2

Views: 565

Answers (2)

smoks

Reputation: 132

The shape of your tensor is 3 across dimension 0, 224 across dimension 1 and 225 across dimension 2.

I would say that tensor.mean([1,2]) calculates the mean over dimensions 1 and 2 together. That's why you are getting 3 values. Each plane spanned by dimensions 1 and 2, of size 224x225, is reduced to a single scalar. Since there are 3 such planes (one per index along dimension 0), you get 3 values back. Each value is the mean of a whole plane of 224x225 = 50,400 values.
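To check this interpretation, here is a minimal sketch that compares tensor.mean([1,2]) with the mean of each plane computed one at a time. The random tensor `t` and the variable names are placeholders, not the asker's data, so the actual values will differ from the ones in the question.

```python
import torch

torch.manual_seed(0)
# Hypothetical stand-in for the asker's tensor: 3 planes of 224 x 225 values.
t = torch.rand(3, 224, 225)

# Reduce over dimensions 1 and 2 in one call; only dimension 0 survives.
per_plane = t.mean([1, 2])

# The same thing done by hand: take the mean of each 224 x 225 plane.
manual = torch.stack([t[i].mean() for i in range(t.shape[0])])

print(per_plane.shape)  # torch.Size([3])
print(torch.allclose(per_plane, manual, atol=1e-6))
```

If the two results agree up to floating-point rounding, each of the 3 returned values is indeed the mean of one plane.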

Upvotes: 0

DerekG

Reputation: 3958

Operations that aggregate along dimensions, such as min, max, mean and sum, take an argument specifying the dimension along which to aggregate. It is common to use these operations across every dimension (e.g. tensor.mean() gives the mean of the entire tensor) or across a single dimension (e.g. tensor.mean(dim=2) or tensor.mean(2) returns the mean of the 225 elements for each of the 3 x 224 vectors, giving a result of shape [3, 224]).
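As a rough illustration of the whole-tensor and single-dimension cases, the sketch below prints the resulting shapes. The tensor `t` is random placeholder data with the shape from the question; `keepdim=True` is an optional argument of mean that keeps the reduced dimension with size 1.

```python
import torch

# Placeholder tensor with the shape from the question.
t = torch.rand(3, 224, 225)

# No dim argument: aggregate over every dimension, giving a 0-d tensor.
overall = t.mean()
print(overall.shape)    # torch.Size([])

# Single dimension: average the 225 elements along dimension 2.
last = t.mean(dim=2)
print(last.shape)       # torch.Size([3, 224])

# keepdim=True keeps the reduced dimension as size 1.
last_keep = t.mean(dim=2, keepdim=True)
print(last_keep.shape)  # torch.Size([3, 224, 1])
```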

PyTorch also allows these operations across a set of multiple dimensions, as in your case. tensor.mean([1,2]) takes the mean of the 224 x 225 elements for each index along the 0th (non-aggregated) dimension. Likewise, if your original tensor shape was a.shape = torch.Size([3,224,10,225]), a.mean([1,3]) would return a tensor of shape [3,10].
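The four-dimensional example can be sketched as follows. The tensor `a` is random placeholder data, and the permute/reshape comparison is just one way to confirm which elements are averaged together, not the only one.

```python
import torch

torch.manual_seed(0)
# Placeholder 4-d tensor with the shape used in the example.
a = torch.rand(3, 224, 10, 225)

# Aggregate over dimensions 1 and 3; dimensions 0 and 2 remain.
reduced = a.mean([1, 3])
print(reduced.shape)  # torch.Size([3, 10])

# Cross-check: move the kept dimensions to the front, flatten the rest,
# and average over the flattened 224 * 225 elements.
check = a.permute(0, 2, 1, 3).reshape(3, 10, -1).mean(dim=-1)
print(torch.allclose(reduced, check, atol=1e-6))
```

The dimensions listed in the argument disappear from the result, and the remaining ones keep their original order.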

Upvotes: 1
