Reputation: 329
I'm working with PyTorch and have run into a problem that I don't know how to solve in a torch/numpy style. For example, suppose I have three PyTorch tensors:
import torch
import numpy as np
indices = torch.from_numpy(np.array([[2, 1, 3, 0], [1, 0, 3, 2]]))
flags = torch.from_numpy(np.array([[False, False, False, True], [False, False, True, True]]))
tensor = torch.from_numpy(np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]]))
Here flags is a boolean tensor that marks which elements of indices should be extracted. Given the extracted indices, I want to set the corresponding elements of tensor to a given constant (say 1e-30). For the example above, I want:
>>> sub_indices = indices.op1(flags)
>>> sub_indices
tensor([[0], [3, 2]])
>>> tensor.op2(sub_indices, 1e-30)
>>> tensor
tensor([[1e-30, 0.5, 1.2, 0.9], [3.1, 2.8, 1e-30, 1e-30]])
Could anyone suggest a solution? I'm currently using a list comprehension, but I find it a bit ugly. I tried indices[flags], but it only returns the 1-D tensor [0, 3, 2], so applying it would change columns 0, 2 and 3 in every row.
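For reference, the behaviour described above can be reproduced with a short sketch. The per-row loop is only an assumed stand-in for the list-comprehension approach mentioned in the question (the asker's actual code is not shown); the names flat, row, idx_row and flag_row are illustrative.

```python
import torch

indices = torch.tensor([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = torch.tensor([[False, False, False, True], [False, False, True, True]])
tensor = torch.tensor([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# Boolean indexing flattens the result, so the row each index came from is lost.
flat = indices[flags]
print(flat)  # tensor([0, 3, 2])

# Assumed baseline: handle each row separately.
# Iterating over a 2-D tensor yields row views, so the assignment
# below modifies `tensor` in place.
for row, idx_row, flag_row in zip(tensor, indices, flags):
    row[idx_row[flag_row]] = 1e-30

print(tensor)
```

This gives the desired result but loops over rows in Python, which is what a vectorized solution would avoid.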
Some additional remarks:
- flags cannot be determined
- indices is assured to be a permutation of the sequence 0 ... N - 1
Below is a NumPy version of the example code, for convenient copy-pasting. I'm not sure whether this can be done in a pure NumPy way.
import numpy as np
indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
Upvotes: 1
Views: 1440
Reputation: 1040
You can sort flags according to indices to create a mask, then use the mask as a mux. Here is some example code:
indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
indices_sorted = indices.argsort(axis=1)
mask = np.take_along_axis(flags, indices_sorted, axis=1)
result = tensor * (1 - mask) + 1e-30 * mask
I'm not very familiar with PyTorch, but I guess it is not a good idea to gather a ragged tensor. Even in the worst case, though, you can convert to and from NumPy arrays.
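To make the reasoning behind the mask explicit, here is a minimal sketch of the same idea. It relies on each row of indices being a permutation, so argsort returns the inverse permutation. The name inverse and the use of np.where in place of the arithmetic mux are choices made for this sketch, not part of the answer above.

```python
import numpy as np

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# For a permutation, argsort gives the inverse permutation:
# inverse[i, j] is the position k for which indices[i, k] == j.
inverse = indices.argsort(axis=1)

# mask[i, j] is therefore the flag attached to the entry of
# indices[i] whose value is j, i.e. "should column j be overwritten?"
mask = np.take_along_axis(flags, inverse, axis=1)

# Select instead of multiply: take 1e-30 where mask is True, else keep tensor.
result = np.where(mask, 1e-30, tensor)
print(result)
```

Selecting with np.where avoids multiplying by zero, which would turn any inf or NaN entries in tensor into NaN even where the mask is False.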
Upvotes: 2
Reputation: 4475
The PyTorch version of @soloice's solution. In PyTorch, torch.gather is used in place of NumPy's np.take_along_axis.
indices = torch.tensor([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = torch.tensor([[False, False, False, True], [False, False, True, True]])
tensor = torch.tensor([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])
indices_sorted = indices.argsort(dim=1)
mask = torch.gather(flags, 1, indices_sorted).float()
result = tensor * (1 - mask) + 1e-30 * mask
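As a possible variant, the mask can be kept boolean and applied with Tensor.masked_fill rather than converted to float and used arithmetically. This is a sketch under the same assumption that each row of indices is a permutation; the name inverse is illustrative.

```python
import torch

indices = torch.tensor([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = torch.tensor([[False, False, False, True], [False, False, True, True]])
tensor = torch.tensor([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# argsort of a permutation is its inverse permutation.
inverse = indices.argsort(dim=1)

# Gather the flags into column order; the result stays a bool tensor.
mask = torch.gather(flags, 1, inverse)

# masked_fill returns a new tensor with 1e-30 wherever mask is True.
result = tensor.masked_fill(mask, 1e-30)
print(mask)
print(result)
```

Use tensor.masked_fill_(mask, 1e-30) instead if the update should happen in place, as in the original question.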
Upvotes: 2