Reputation: 2569
I have a torch.tensor of shape (n, m) and I want to remove the duplicated rows (or at least find them). For example:
t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6]])
t2 = remove_duplicates(t1)
t2 should now be equal to tensor([[1, 2, 3], [4, 5, 6]]), that is, the third and fourth rows (indices 2 and 3) are removed. Do you know a way to perform this operation?
I was thinking of doing something with torch.unique, but I cannot figure out how.
Upvotes: 4
Views: 6592
Reputation: 175
You can simply use the dim parameter of torch.unique:
t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7, 8, 9]])
torch.unique(t1, dim=0)
In this way you obtain the result you want:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])
The meaning of that parameter is explained in the torch.unique documentation.
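Two details may matter depending on the use case. First, torch.unique(t, dim=0) returns the distinct rows sorted, not in their original order. Second, the question also asks how to find the duplicates. The sketch below covers both, assuming a 2-D tensor: return_counts=True shows which distinct rows occur more than once, and return_inverse=True combined with a stable sort recovers the first occurrence of each row, so the original order is kept and a mask marks the repeated rows. The helper name remove_duplicate_rows is hypothetical and not part of PyTorch.

```python
import torch
from typing import Tuple


def remove_duplicate_rows(t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: drop repeated rows of a 2-D tensor, keeping first occurrences.

    Returns the deduplicated tensor (rows in their original order) and a boolean
    mask that is True for every row repeating an earlier one.
    """
    # inverse[i] is the index (into the sorted unique rows) of original row i
    _, inverse = torch.unique(t, dim=0, return_inverse=True)
    # A stable sort groups equal rows together while preserving their original
    # order, so the first element of each group is that row's first occurrence
    sorted_inv, order = torch.sort(inverse, stable=True)
    is_group_start = torch.ones_like(sorted_inv, dtype=torch.bool)
    is_group_start[1:] = sorted_inv[1:] != sorted_inv[:-1]
    # Indices of first occurrences, put back into original row order
    keep = order[is_group_start].sort().values
    is_dup = torch.ones(t.size(0), dtype=torch.bool, device=t.device)
    is_dup[keep] = False
    return t[keep], is_dup


t1 = torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [4, 5, 6], [7, 8, 9]])

# Distinct rows that appear more than once
uniq, counts = torch.unique(t1, dim=0, return_counts=True)
repeated = uniq[counts > 1]

# Deduplicated rows in original order, plus a mask of the removed rows
deduped, is_dup = remove_duplicate_rows(t1)
```

For t1 above the two approaches give the same rows, but for an input such as [[4, 5, 6], [1, 2, 3], [4, 5, 6]] the helper keeps [4, 5, 6] first, whereas torch.unique(..., dim=0) alone returns the rows sorted.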
Upvotes: 8