sachinruk

Count unique elements in a PyTorch tensor

Suppose I have the following tensor: y = torch.randint(0, 3, (10,)). How would I go about counting the 0s, 1s, and 2s in it?

The only way I can think of is collections.Counter(y.tolist()), but I was wondering whether there is a more "PyTorch" way of doing this. One use case would be building a confusion matrix from predictions.

Answers (1)

Ivan

You can use torch.unique with the return_counts option:

>>> import torch
>>> x = torch.randint(0, 3, (10,))
>>> x
tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])

>>> x.unique(return_counts=True)
(tensor([0, 1, 2]), tensor([2, 6, 2]))
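Note that torch.unique only reports values that actually occur. If the labels are non-negative integers below a known number of classes (an assumption here), torch.bincount is a possible alternative: its minlength argument makes absent classes show up with a count of 0. The same trick can produce a confusion matrix by encoding each (target, pred) pair as a single index. This is a minimal sketch; class_counts and confusion_matrix are hypothetical helper names, not part of the PyTorch API.

```python
import torch

# Hypothetical helper: assumes y is a 1-D integer tensor with values in [0, num_classes).
def class_counts(y: torch.Tensor, num_classes: int) -> torch.Tensor:
    # bincount counts each non-negative integer; minlength pads the result
    # so classes that never appear still get an entry with count 0.
    return torch.bincount(y, minlength=num_classes)

# Hypothetical helper: rows index the true label, columns the predicted label.
def confusion_matrix(target: torch.Tensor, pred: torch.Tensor, num_classes: int) -> torch.Tensor:
    # Encode each (target, pred) pair as one index in [0, num_classes**2),
    # count those indices, then reshape the flat counts into a square matrix.
    idx = target * num_classes + pred
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    return counts.reshape(num_classes, num_classes)

y = torch.tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])
print(class_counts(y, num_classes=4))  # tensor([2, 6, 2, 0])

target = torch.tensor([0, 1, 2, 2, 1])
pred = torch.tensor([0, 2, 2, 2, 1])
print(confusion_matrix(target, pred, num_classes=3))  # rows: [1, 0, 0], [0, 1, 1], [0, 0, 2]
```

bincount raises an error on negative values, so this only fits label-style data; torch.unique works for arbitrary values.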
