Reputation: 219
I'm trying to delete an item from a tensor.
In the example below, how can I remove the third item from the tensor?
tensor([[-5.1949, -6.2621, -6.2051, -5.8983, -6.3586, -6.2434, -5.8923, -6.1901,
-6.5713, -6.2396, -6.1227, -6.4196, -3.4311, -6.8903, -6.1248, -6.3813,
-6.0152, -6.7449, -6.0523, -6.4341, -6.8579, -6.1961, -6.5564, -6.6520,
-5.9976, -6.3637, -5.7560, -6.7946, -5.4101, -6.1310, -3.3249, -6.4584,
-6.2202, -6.3663, -6.9293, -6.9262]], grad_fn=<SqueezeBackward1>)
Upvotes: 11
Views: 39630
Reputation: 21
If you want to remove multiple elements by index, you could do the following.
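import torch

# Removing elements by marking them with a sentinel value (-1)
x = torch.arange(10)      # values 0..9
y = torch.arange(4) * 2   # indices to remove: 0, 2, 4, 6
print(x)
print(y)
x[y] = -1                 # overwrite the unwanted positions (modifies x in place)
print(x[x != -1])         # keep everything that is not the sentinel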
This gives the output:
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
tensor([0, 2, 4, 6])
tensor([1, 3, 5, 7, 8, 9])
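A minimal sketch of the same idea without a sentinel value, assuming the indices in `y` are valid and you want a new tensor rather than modifying `x` in place: build a boolean "keep" mask, switch off the positions to drop, and index with it. Unlike the `-1` trick, this still works when `-1` (or any other value) can legitimately occur in the data. The name `keep` is just illustrative.

```python
import torch

x = torch.arange(10)       # values 0..9
y = torch.arange(4) * 2    # indices to remove: 0, 2, 4, 6

# Boolean mask that is True for every position we want to keep
keep = torch.ones(x.numel(), dtype=torch.bool)
keep[y] = False            # switch off the indices to remove

result = x[keep]           # x itself is left unchanged
print(result)              # tensor([1, 3, 5, 7, 8, 9])
```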
Upvotes: 2
Reputation: 2636
I think doing this with boolean indexing is more readable:
t[t != t[0, 3]]
The result contains the same values as the cat solution below, but boolean indexing flattens it to a 1-D tensor (shape [35] instead of [1, 35]).
BE CAREFUL: this usually works for floats, but if the value at [0,3] occurs more than once in the tensor, every occurrence of it will be removed.
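A small sketch illustrating both caveats of value-based masking, using a made-up 1×4 tensor in which one value is repeated: every occurrence of the value is dropped, and the result comes back as a 1-D tensor instead of keeping the leading dimension of size 1. Position-based approaches, such as slicing and concatenating with torch.cat, avoid the first problem.

```python
import torch

# Made-up example: the value 2.0 appears twice in the row
t = torch.tensor([[1.0, 2.0, 3.0, 2.0]])

out = t[t != t[0, 1]]      # intended to drop only the element at [0, 1]
print(out)                 # tensor([1., 3.])  -> both 2.0 entries are gone
print(out.shape)           # torch.Size([2])   -> boolean indexing flattens to 1-D
```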
Upvotes: 7
Reputation: 7713
You can slice out the parts before and after index 3 and then concatenate them with torch.cat:
t.shape
torch.Size([1, 36])
t = torch.cat((t[:, :3], t[:, 4:]), dim=1)   # keep columns 0-2 and 4-35, dropping index 3
t.shape
torch.Size([1, 35])
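A minimal sketch generalizing the slicing idea to any index and dimension. The helper name `remove_at` is hypothetical and it assumes `index` lies in `[-size, size)`. It builds the list of positions to keep and uses `torch.index_select`, so the other dimensions (e.g. the leading 1 in the question's `[1, 36]` tensor) are preserved and gradients still flow back to the input. The random input only stands in for the question's tensor.

```python
import torch

def remove_at(t: torch.Tensor, index: int, dim: int = -1) -> torch.Tensor:
    """Hypothetical helper: return a copy of `t` with position `index` removed along `dim`."""
    size = t.size(dim)
    # Positions 0..size-1, minus the one to drop
    keep = torch.arange(size, device=t.device)
    keep = keep[keep != index % size]   # `% size` lets negative indices like -1 work
    return t.index_select(dim, keep)

# Stand-in for the question's [1, 36] tensor (random values, requires grad)
t = torch.randn(1, 36, requires_grad=True)
r = remove_at(t, 3, dim=1)              # drop index 3 along the columns
print(r.shape)                          # torch.Size([1, 35])

# Same values as the torch.cat slicing approach
expected = torch.cat((t[:, :3], t[:, 4:]), dim=1)
print(torch.equal(r, expected))         # True
```

Compared with torch.cat, the index-list version needs no special handling when the removed position is the first or last one along the dimension.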
Upvotes: 2