amnesic

Reputation: 267

Is there a way to show actual calculations in PyTorch?

I am learning PyTorch and am new to operations with tensors.

If I have a tensor img_t of shape [3, 5, 5] like this:

       [[[0.1746, 0.3302, 0.5370, 0.8443, 0.6937],
         [0.8831, 0.1861, 0.5422, 0.0556, 0.7868],
         [0.6042, 0.9836, 0.1444, 0.9010, 0.9221],
         [0.9043, 0.5713, 0.9546, 0.8339, 0.8730],
         [0.4675, 0.1163, 0.4938, 0.5938, 0.1594]],

        [[0.2132, 0.0206, 0.3247, 0.9355, 0.5855],
         [0.4695, 0.5201, 0.8118, 0.0585, 0.1142],
         [0.3338, 0.2122, 0.7579, 0.8533, 0.0149],
         [0.0757, 0.0131, 0.6886, 0.9024, 0.1123],
         [0.2685, 0.6591, 0.1735, 0.9247, 0.6166]],

        [[0.3608, 0.5325, 0.6559, 0.3232, 0.1126],
         [0.5034, 0.5091, 0.5101, 0.4270, 0.8210],
         [0.3605, 0.4516, 0.7056, 0.1853, 0.6339],
         [0.3894, 0.7398, 0.2288, 0.5185, 0.5489],
         [0.0977, 0.1364, 0.6918, 0.3545, 0.7969]]]

I get that img_t.mean(0) takes the mean across the 3 arrays, element by element, returning a 5x5 tensor. And I get that img_t.mean(1) takes the mean within each array down the columns, returning a 3x5 tensor. But what does img_t.mean(2) do? My only guess is that it takes the mean within each array along the rows, but I expected that to return a 5x5 tensor, whereas the actual result is a 3x5 tensor.

I tried reading the docs and other SO posts and running my own numbers, but I still can't figure it out. It would be really helpful to see what calculation PyTorch is actually doing while I'm still learning. Please advise.

Upvotes: 1

Views: 88

Answers (1)

core_not_dumped

Reputation: 877

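The argument to mean specifies the dimension to reduce: values are averaged along that dimension, and that dimension is removed from the shape of the result.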

Passing the dimension by keyword makes the code more readable:

img_t.mean(dim = 0).shape

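img_t.mean(2) averages along the last dimension, that is, across the 5 values in each row of each array. That gives one value per row, and there are 5 rows in each of the 3 arrays, so the result is a 3x5 tensor rather than a 5x5 one, just as you observed.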
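To see the arithmetic itself, here is a minimal plain-Python sketch (no PyTorch involved) using only the first 5x5 array from the question. The names block0, row_means, and col_means are made up for illustration. It should mirror what mean(2) and mean(1) compute for that one array; in both cases you get 5 numbers per array, which is why both results are 3x5.

```python
# First 5x5 array of img_t from the question, as plain nested lists.
block0 = [
    [0.1746, 0.3302, 0.5370, 0.8443, 0.6937],
    [0.8831, 0.1861, 0.5422, 0.0556, 0.7868],
    [0.6042, 0.9836, 0.1444, 0.9010, 0.9221],
    [0.9043, 0.5713, 0.9546, 0.8339, 0.8730],
    [0.4675, 0.1163, 0.4938, 0.5938, 0.1594],
]

# What mean(2) does for this array: average the 5 values inside each row.
row_means = [sum(row) / len(row) for row in block0]

# What mean(1) does for this array: average the 5 values down each column.
col_means = [sum(col) / len(col) for col in zip(*block0)]

print(row_means)  # 5 values, one per row
print(col_means)  # 5 values, one per column
```

Repeating this for the other two arrays gives 3 lists of 5 values each, i.e. a 3x5 result for either dimension.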

import torch

img_t = torch.tensor([[[0.1746, 0.3302, 0.5370, 0.8443, 0.6937],
         [0.8831, 0.1861, 0.5422, 0.0556, 0.7868],
         [0.6042, 0.9836, 0.1444, 0.9010, 0.9221],
         [0.9043, 0.5713, 0.9546, 0.8339, 0.8730],
         [0.4675, 0.1163, 0.4938, 0.5938, 0.1594]],

        [[0.2132, 0.0206, 0.3247, 0.9355, 0.5855],
         [0.4695, 0.5201, 0.8118, 0.0585, 0.1142],
         [0.3338, 0.2122, 0.7579, 0.8533, 0.0149],
         [0.0757, 0.0131, 0.6886, 0.9024, 0.1123],
         [0.2685, 0.6591, 0.1735, 0.9247, 0.6166]],

        [[0.3608, 0.5325, 0.6559, 0.3232, 0.1126],
         [0.5034, 0.5091, 0.5101, 0.4270, 0.8210],
         [0.3605, 0.4516, 0.7056, 0.1853, 0.6339],
         [0.3894, 0.7398, 0.2288, 0.5185, 0.5489],
         [0.0977, 0.1364, 0.6918, 0.3545, 0.7969]]])

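# Each call averages along one dimension and removes it from the shape [3, 5, 5].
print(img_t.mean(dim=0).shape)  # averages across the 3 arrays
print(img_t.mean(dim=1).shape)  # averages down the columns of each array
print(img_t.mean(dim=2).shape)  # averages across the rows of each array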

result:

torch.Size([5, 5])
torch.Size([3, 5])
torch.Size([3, 5])
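Part of the confusion may come from the last two dimensions both having size 5, so you cannot tell from the shape which one was removed. As a rough sketch, the example below uses a made-up tensor x of shape [2, 3, 4] (not the data from the question) so that each dimension is distinguishable, and then reproduces one entry of x.mean(dim=2) by hand.

```python
import torch

# Hypothetical tensor with three different sizes, so the removed dimension is obvious.
x = torch.arange(24, dtype=torch.float32).reshape(2, 3, 4)

print(x.mean(dim=0).shape)  # torch.Size([3, 4])
print(x.mean(dim=1).shape)  # torch.Size([2, 4])
print(x.mean(dim=2).shape)  # torch.Size([2, 3])

# One entry of x.mean(dim=2), computed by hand:
# sum the 4 values in row 1 of array 0, then divide by the size of dim 2.
manual = x[0, 1, :].sum() / x.shape[2]
print(manual, x.mean(dim=2)[0, 1])  # both are 5.5
```

Indexing a single slice like x[0, 1, :] and averaging it yourself is a simple way to check what any reduction is doing on your own data.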

Upvotes: 3
