My objective is to count every distinct pair of adjacent values in a tensor x.
Say my tensor is the following (x looks like a list here, but it is a PyTorch tensor):
x = [1,2,1,2,4,5]
I would want my output to be:
[1,2] = 2
[2,1] = 1
[2,4] = 1
[4,5] = 1
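As a reference point for the expected output, here is a plain-Python baseline (not a tensor-based solution). It assumes x is a small 1-D integer tensor, converts it to a list, and counts neighbouring pairs with collections.Counter. This leaves PyTorch entirely, so it is only meant for checking results on small inputs.

```python
import torch
from collections import Counter

# Assumed example input: the tensor from the question, as integers
x = torch.tensor([1, 2, 1, 2, 4, 5])
values = x.tolist()

# zip pairs each element with its right-hand neighbour: (1, 2), (2, 1), ...
counts = Counter(zip(values, values[1:]))

for pair, n in counts.items():
    print(list(pair), "=", n)
```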
I thought about reshaping the tensor with tensor.view so that it looks like:
x = [[1,2],[2,1],[1,2],[2,4],[4,5]]
but I couldn't find a solution that works for a tensor of any length.
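A sketch of why view alone cannot do this, assuming the same example tensor: view only reinterprets the existing elements, so 6 values cannot become 5 pairs (10 values). Tensor.unfold(dimension, size, step), a different built-in than the one the question tried, produces overlapping windows of length 2 that advance by 1.

```python
import torch

# Assumed example input (integer version of the question's tensor)
x = torch.tensor([1, 2, 1, 2, 4, 5])

# view() cannot duplicate elements: a (5, 2) shape needs 10 values, x has 6
try:
    x.view(5, 2)
except RuntimeError as err:
    print("view failed:", err)

# unfold(dimension=0, size=2, step=1) returns a (len(x) - 1, 2) view
# whose row i is the adjacent pair (x[i], x[i + 1])
pairs = x.unfold(0, 2, 1)
print(pairs)
```

From here, pairs.unique(dim=0, return_counts=True) would count the rows, which mirrors counting columns with dim=1 on a (2, n-1) layout.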
Is this even the best way to go about it, or is there a built-in function?
As @ihdv showed, you can stack shifted views of x with torch.stack or torch.vstack to get a tensor of overlapping adjacent pairs, one pair per column:
>>> p = torch.vstack((x[:-1], x[1:]))
>>> p
tensor([[1., 2., 1., 2., 4.],
        [2., 1., 2., 4., 5.]])
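A self-contained version of the stacking step, assuming an integer tensor (the REPL output above suggests the answer used a float tensor, which changes only the dtype). For 1-D inputs, torch.stack along its default dim=0 and torch.vstack give the same (2, n-1) result.

```python
import torch

# Assumed example input: integer version of the question's tensor
x = torch.tensor([1, 2, 1, 2, 4, 5])

# Row 0 holds each element, row 1 its right-hand neighbour,
# so column i is the adjacent pair (x[i], x[i + 1])
p_vstack = torch.vstack((x[:-1], x[1:]))
p_stack = torch.stack((x[:-1], x[1:]))  # stacks along dim=0 by default

print(p_vstack)
print(torch.equal(p_vstack, p_stack))  # True for 1-D inputs
```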
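Then apply torch.unique along dim=1, so that each column (each pair) is treated as a single element, with return_counts=True to get the count of each distinct pair: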
>>> p.unique(dim=1, return_counts=True)
(tensor([[1., 2., 2., 4.],
         [2., 1., 4., 5.]]), tensor([2, 1, 1, 1]))
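To get output in the question's "[a,b] = n" form, the two steps above can be wrapped in a small helper. The name count_adjacent_pairs is hypothetical, and this is a sketch assuming a 1-D tensor; note that unique returns the pairs sorted, not in order of first appearance.

```python
import torch


def count_adjacent_pairs(x: torch.Tensor) -> dict:
    """Hypothetical helper: map each distinct adjacent pair to its count."""
    if x.dim() != 1:
        raise ValueError("expected a 1-D tensor")
    if x.numel() < 2:
        return {}  # no neighbours, so no pairs
    # Column i is the pair (x[i], x[i + 1])
    pairs = torch.vstack((x[:-1], x[1:]))
    # dim=1 treats each column as one element; results come back sorted
    uniq, counts = pairs.unique(dim=1, return_counts=True)
    # Transpose so each row is one pair, then convert to plain Python types
    return {tuple(col): int(c) for col, c in zip(uniq.T.tolist(), counts.tolist())}


x = torch.tensor([1, 2, 1, 2, 4, 5])
for pair, n in count_adjacent_pairs(x).items():
    print(list(pair), "=", n)
```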