Suppose I have the following tensor: `y = torch.randint(0, 3, (10,))`. How would you go about counting the 0s, 1s and 2s in it?

The only way I can think of is `collections.Counter(y.tolist())`, but I was wondering if there is a more "PyTorch" way of doing this. One use case would be building a confusion matrix from predictions.
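For comparison, here is a minimal sketch of the `Counter` baseline. It uses the fixed example tensor from the answer instead of random data so the counts are reproducible. The tensor has to be converted to Python ints first: iterating a tensor yields new 0-d tensors, and tensors hash by object identity, so `Counter(y)` on the tensor itself would count every element as distinct.

```python
import collections

import torch

# Fixed example values (taken from the answer) so the result is reproducible.
y = torch.tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])

# Convert to plain Python ints; Counter(y) on the tensor itself would key on
# each 0-d tensor object (hashed by identity), giving ten separate entries.
counts = collections.Counter(y.tolist())
print(counts)
```

This works, but it moves the data out of PyTorch (and off the GPU, if `y` lives there).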
You can use `torch.unique` with the `return_counts` option:
```python
>>> x = torch.randint(0, 3, (10,))
>>> x
tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])
>>> x.unique(return_counts=True)
(tensor([0, 1, 2]), tensor([2, 6, 2]))
```
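One caveat is that `torch.unique` only reports values that actually occur, so a class missing from a batch simply disappears from the output. If the values are non-negative integers below a known number of classes (an assumption; `num_classes`, `target` and `pred` below are hypothetical names with made-up data), `torch.bincount` with `minlength` gives a fixed-length count vector instead. The same trick sketches the confusion-matrix use case. Each (target, prediction) pair is encoded as the single index `target * num_classes + pred`, counted, and reshaped into a square matrix.

```python
import torch

x = torch.tensor([1, 1, 0, 2, 1, 0, 1, 1, 2, 1])
num_classes = 3  # assumed to be known in advance

# One count per value 0..num_classes-1; minlength keeps zero entries for
# classes that do not appear (torch.unique would omit them).
counts = torch.bincount(x, minlength=num_classes)
print(counts)

# Hypothetical labels and predictions for a confusion matrix.
target = torch.tensor([0, 1, 2, 2, 1, 0])
pred = torch.tensor([0, 2, 2, 1, 1, 0])

# Encode each (target, pred) pair as one flat index, count every index,
# then reshape so rows are true classes and columns are predicted classes.
flat = target * num_classes + pred
conf = torch.bincount(flat, minlength=num_classes ** 2).reshape(num_classes, num_classes)
print(conf)
```

`torch.bincount` expects a 1-D tensor of non-negative integers. For arbitrary or unknown values, `unique(return_counts=True)` remains the more general choice.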