Reputation: 267
I am learning PyTorch and am new to operations with tensors. If I have a tensor img_t of shape [3, 5, 5] like this:
[[[0.1746, 0.3302, 0.5370, 0.8443, 0.6937],
[0.8831, 0.1861, 0.5422, 0.0556, 0.7868],
[0.6042, 0.9836, 0.1444, 0.9010, 0.9221],
[0.9043, 0.5713, 0.9546, 0.8339, 0.8730],
[0.4675, 0.1163, 0.4938, 0.5938, 0.1594]],
[[0.2132, 0.0206, 0.3247, 0.9355, 0.5855],
[0.4695, 0.5201, 0.8118, 0.0585, 0.1142],
[0.3338, 0.2122, 0.7579, 0.8533, 0.0149],
[0.0757, 0.0131, 0.6886, 0.9024, 0.1123],
[0.2685, 0.6591, 0.1735, 0.9247, 0.6166]],
[[0.3608, 0.5325, 0.6559, 0.3232, 0.1126],
[0.5034, 0.5091, 0.5101, 0.4270, 0.8210],
[0.3605, 0.4516, 0.7056, 0.1853, 0.6339],
[0.3894, 0.7398, 0.2288, 0.5185, 0.5489],
[0.0977, 0.1364, 0.6918, 0.3545, 0.7969]]]
I get that img_t.mean(0) takes the mean across all 3 arrays, element by element at matching indices, returning a 5x5 tensor. And I get that img_t.mean(1) takes the mean of each array along its columns, returning a 3x5 tensor. But what does img_t.mean(2) do? My only guess is that it takes the mean of each array along its rows, but I expected that to return a 5x5 tensor, and the actual result is a 3x5 tensor.
I tried reading the docs and other SO posts and running my own numbers, but I still can't figure it out. It would be really helpful to see what calculation PyTorch is actually doing while I'm still learning. Please advise.
Upvotes: 1
Views: 88
Reputation: 877
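The argument to mean specifies which dimension is reduced (averaged over).
To make the code more readable, pass it by keyword:
img_t.mean(dim=0).shape
And img_t.mean(2) does not return a 5x5 tensor. It averages over dimension 2, the 5 values in each row, so that dimension disappears from the shape and the result is a 3x5 tensor: one mean per row for each of the 3 arrays.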
import torch
img_t = torch.tensor([[[0.1746, 0.3302, 0.5370, 0.8443, 0.6937],
[0.8831, 0.1861, 0.5422, 0.0556, 0.7868],
[0.6042, 0.9836, 0.1444, 0.9010, 0.9221],
[0.9043, 0.5713, 0.9546, 0.8339, 0.8730],
[0.4675, 0.1163, 0.4938, 0.5938, 0.1594]],
[[0.2132, 0.0206, 0.3247, 0.9355, 0.5855],
[0.4695, 0.5201, 0.8118, 0.0585, 0.1142],
[0.3338, 0.2122, 0.7579, 0.8533, 0.0149],
[0.0757, 0.0131, 0.6886, 0.9024, 0.1123],
[0.2685, 0.6591, 0.1735, 0.9247, 0.6166]],
[[0.3608, 0.5325, 0.6559, 0.3232, 0.1126],
[0.5034, 0.5091, 0.5101, 0.4270, 0.8210],
[0.3605, 0.4516, 0.7056, 0.1853, 0.6339],
[0.3894, 0.7398, 0.2288, 0.5185, 0.5489],
[0.0977, 0.1364, 0.6918, 0.3545, 0.7969]]])
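# The reduced dimension is dropped from the original shape [3, 5, 5].
print(img_t.mean(dim=0).shape)  # averages over the 3 arrays
print(img_t.mean(dim=1).shape)  # averages over the 5 rows of each array
print(img_t.mean(dim=2).shape)  # averages over the 5 values in each row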
result:
torch.Size([5, 5])
torch.Size([3, 5])
torch.Size([3, 5])
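To see the calculation itself, here is a minimal sketch that reproduces mean(dim=2) with explicit loops. It uses a small made-up tensor x of shape [2, 3, 4] instead of img_t so the numbers are easy to check by hand; the names x, manual and builtin are just for illustration.

```python
import torch

# Small tensor with values 0..23 so the averages are easy to verify by hand.
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

# mean(dim=2) collapses the last dimension:
# out[i, j] is the average of the values x[i, j, 0], ..., x[i, j, 3].
manual = torch.empty(2, 3)
for i in range(x.shape[0]):
    for j in range(x.shape[1]):
        manual[i, j] = x[i, j, :].sum() / x.shape[2]

builtin = x.mean(dim=2)

print(builtin.shape)                     # dimension 2 is gone: [2, 3]
print(torch.allclose(manual, builtin))   # the loop matches the built-in
```

The same rule applies to img_t: out[i, j] is the average of the 5 values img_t[i, j, :], so the [3, 5, 5] input becomes a [3, 5] output.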
Upvotes: 3