Shamoon

Reputation: 43491

How can I count the number of 1's and 0's in the second dimension of a PyTorch Tensor?

I have a tensor with size: torch.Size([64, 2941]), which is 64 batches of 2941 elements.

Across all 64 batches, I want to count the total number of 1s and 0s at each position of the second dimension (all 2941 of them), so that I end up with those counts as a tensor of size torch.Size([2941]).

How do I do that?

Upvotes: 3

Views: 12349

Answers (1)

Berriel

Reputation: 13601

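You can compare against 1 and sum over the first dimension (dim=0). This collapses the batch rows and leaves one count per column: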

import torch
torch.manual_seed(2020)

# x is fake data; it should be your 64x2941 tensor
x = (torch.rand((3,4)) > 0.5).to(torch.int32)
print(x)
# tensor([[0, 0, 1, 0],
#         [0, 0, 1, 1],
#         [0, 1, 1, 1]], dtype=torch.int32)

ones = (x == 1).sum(dim=0)
print(ones)
# tensor([0, 1, 3, 2])
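A minimal sketch applying the same idea to the question's actual shape. It assumes random 0/1 data from torch.randint as a stand-in for the real tensor (the values are illustrative only). Counting zeros with (x == 0) works the same way:

```python
import torch

torch.manual_seed(0)

# Stand-in for the question's data: 64 rows (batch) x 2941 columns of 0/1 values.
# Assumption: random values, used only to illustrate the shapes.
x = torch.randint(0, 2, (64, 2941))

# Reduce over dim 0 (the batch dimension) to get one count per column.
ones = (x == 1).sum(dim=0)
zeros = (x == 0).sum(dim=0)

print(ones.shape)   # torch.Size([2941])
print(zeros.shape)  # torch.Size([2941])
```

Because x here only holds 0s and 1s, ones + zeros is 64 in every column.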

If x is binary (contains only 0s and 1s), you can get the number of zeros by a simple subtraction. Since the sum runs over dim=0, subtract from the number of rows, x.shape[0]:

zeros = x.shape[0] - ones
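If x might contain values other than 0 and 1, the subtraction no longer gives the zero count. A small sketch with a hypothetical non-binary tensor (the 2 in the first row is made up for illustration) compares counting zeros directly against the subtraction:

```python
import torch

# Hypothetical tensor that is NOT strictly binary: the first row contains a 2.
x = torch.tensor([[0, 0, 1, 2],
                  [0, 0, 1, 1],
                  [0, 1, 1, 1]])

ones = (x == 1).sum(dim=0)
# Counting zeros directly stays correct whatever other values x holds.
zeros = (x == 0).sum(dim=0)
# The subtraction assumes every non-1 entry is a 0, so the 2 is miscounted as a zero.
zeros_by_subtraction = x.shape[0] - ones

print(ones)                  # tensor([0, 1, 3, 2])
print(zeros)                 # tensor([3, 2, 0, 0])
print(zeros_by_subtraction)  # tensor([3, 2, 0, 1])
```

The two zero counts differ only in the last column, where the 2 sits.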

Upvotes: 9
