Sourav

Reputation: 866

Deleting Rows in Torch Tensor

I have a torch tensor as follows:

import torch

a = torch.tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
     [0.3078, 0.3854, 0.3981, 0.5200],
     [0.1363, 0.4060, 0.2030, 0.4940],
     [0.1640, 0.6025, 0.2267, 0.7036],
     [0.2445, 0.3032, 0.3300, 0.4253]], dtype=torch.float64)

If the first value of a row is less than 0.2, the whole row should be deleted. The expected output is:

tensor(
[[0.2215, 0.5859, 0.4782, 0.7411],
[0.3078, 0.3854, 0.3981, 0.5200],
[0.2445, 0.3032, 0.3300, 0.4253]],  dtype=torch.float64)

I tried looping through the tensor and appending the valid rows to a new empty tensor, but without success. Is there an efficient way to get this result?
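For comparison, here is a minimal sketch of how the loop-and-append approach described above can be made to work, assuming the kept rows are collected in a plain Python list and combined once at the end with `torch.stack`. Growing a tensor inside the loop (for example with `torch.cat` on every iteration) copies the data each time, and a Python-level loop is generally much slower than a vectorized operation on large tensors.

```python
import torch

a = torch.tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
     [0.3078, 0.3854, 0.3981, 0.5200],
     [0.1363, 0.4060, 0.2030, 0.4940],
     [0.1640, 0.6025, 0.2267, 0.7036],
     [0.2445, 0.3032, 0.3300, 0.4253]], dtype=torch.float64)

# Iterating over a 2-D tensor yields its rows as 1-D tensors.
# Collect the rows to keep in a Python list instead of growing a tensor.
kept_rows = []
for row in a:
    if row[0] >= 0.2:  # drop rows whose first value is below 0.2
        kept_rows.append(row)

# Stack the kept rows back into a single 2-D tensor (dtype is preserved).
result = torch.stack(kept_rows)
print(result)
```

Note that `torch.stack` raises an error if `kept_rows` is empty, so a loop like this needs a guard if every row might be filtered out.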

Upvotes: 5

Views: 19205

Answers (1)

Wasi Ahmad

Reputation: 37741

Code

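import torch

a = torch.tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
     [0.3078, 0.3854, 0.3981, 0.5200],
     [0.1363, 0.4060, 0.2030, 0.4940],
     [0.1640, 0.6025, 0.2267, 0.7036],
     [0.2445, 0.3032, 0.3300, 0.4253]])

# Keep only the rows whose first value is at least 0.2
# (i.e. delete every row whose first value is less than 0.2).
y = a[a[:, 0] >= 0.2]
print(y)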

Output

tensor([[0.2215, 0.5859, 0.4782, 0.7411],
        [0.3078, 0.3854, 0.3981, 0.5200],
        [0.2445, 0.3032, 0.3300, 0.4253]])
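To show what the one-liner does, here is a sketch that splits it into steps, using the question's `float64` tensor (the names `mask`, `idx` and `y_explicit` are only illustrative). Indexing with a 1-D boolean tensor along the first dimension keeps exactly the rows where the mask is `True`; turning the mask into row indices with `torch.nonzero` and gathering them with `torch.index_select` gives the same result explicitly.

```python
import torch

a = torch.tensor(
    [[0.2215, 0.5859, 0.4782, 0.7411],
     [0.3078, 0.3854, 0.3981, 0.5200],
     [0.1363, 0.4060, 0.2030, 0.4940],
     [0.1640, 0.6025, 0.2267, 0.7036],
     [0.2445, 0.3032, 0.3300, 0.4253]], dtype=torch.float64)

# Step 1: one boolean per row, True where the row should be kept.
mask = a[:, 0] >= 0.2  # values: True, True, False, False, True

# Step 2: boolean indexing on dim 0 selects whole rows and keeps the dtype.
y = a[mask]

# Equivalent explicit form: convert the mask to row indices, then gather.
idx = torch.nonzero(mask).squeeze(1)  # row indices 0, 1, 4
y_explicit = torch.index_select(a, 0, idx)

print(y)
print(torch.equal(y, y_explicit))  # True
```

Unlike `torch.masked_select`, which returns a flattened 1-D tensor, indexing with a row mask keeps the 2-D shape, which is what deleting whole rows requires.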

Upvotes: 10
