TX Shi

Reputation: 329

Is there a numpy/torch-style way to set values given an index ndarray and a flag ndarray?

I'm working with PyTorch and have run into a problem that I don't know how to solve in an idiomatic torch/numpy style. For example, suppose I have three PyTorch tensors:

import torch
import numpy as np

indices = torch.from_numpy(np.array([[2, 1, 3, 0], [1, 0, 3, 2]]))
flags = torch.from_numpy(np.array([[False, False, False, True], [False, False, True, True]]))
tensor = torch.from_numpy(np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]]))

Here flags is a boolean tensor marking which elements of indices should be extracted. Given the extracted indices, I want to set the corresponding elements in each row of tensor to a given constant (say 1e-30). For the example above, I want the following, where op1 and op2 stand for the operations I'm looking for:

>>> sub_indices = indices.op1(flags)
>>> sub_indices
tensor([[0], [3, 2]])
>>> tensor.op2(sub_indices, 1e-30)
>>> tensor
tensor([[1e-30, 0.5, 1.2, 0.9], [3.1, 2.8, 1e-30, 1e-30]])

Could anyone suggest a solution? I'm currently using a list comprehension, but it feels a bit ugly. I tried indices[flags], but it returns only the 1-D array [0, 3, 2], so applying it would change columns 0, 2, and 3 in every row.

Some additional remarks:

Below is a numpy version of the example code, for convenient copy-pasting. I wonder whether this can be done in pure numpy.

import numpy as np

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

Upvotes: 1

Views: 1440

Answers (2)

soloice

Reputation: 1040

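You can reorder flags according to indices to create a mask, then use the mask as a mux (selector). This relies on each row of indices being a permutation of the column positions, as in the example. Here is some example code: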

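import numpy as np

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# argsort of a permutation gives its inverse permutation, row by row
indices_sorted = indices.argsort(axis=1)
# mask[i, indices[i, j]] == flags[i, j]: flags moved to the columns they refer to
mask = np.take_along_axis(flags, indices_sorted, axis=1)
# keep tensor where mask is False, use 1e-30 where mask is True
result = tensor * (1 - mask) + 1e-30 * mask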

I'm not very familiar with PyTorch, but I suspect gathering a ragged tensor is not a good idea. In the worst case, though, you can convert to and from numpy arrays.
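
If the rows of indices are not guaranteed to be permutations, one possible alternative is to pair each flagged index with its row number and assign through integer indexing. This is only a sketch on the example data; the names rows, cols and result are my own choices for illustration.

```python
import numpy as np

indices = np.array([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = np.array([[False, False, False, True], [False, False, True, True]])
tensor = np.array([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# Row number of every flagged position, in row-major order
rows = np.nonzero(flags)[0]
# Column to overwrite for each flagged position (same row-major order)
cols = indices[flags]

result = tensor.copy()        # leave the original array untouched
result[rows, cols] = 1e-30    # paired (row, column) assignment
```

Because rows and cols are matched element by element, only the positions (0, 0), (1, 3) and (1, 2) are overwritten, rather than columns 0, 2 and 3 in every row.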

Upvotes: 2

zihaozhihao

Reputation: 4475

The PyTorch version of @soloice's solution. In PyTorch, torch.gather (rather than torch.take) is the counterpart of np.take_along_axis.

import torch

indices = torch.tensor([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = torch.tensor([[False, False, False, True], [False, False, True, True]])
tensor = torch.tensor([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# inverse permutation of each row of indices
indices_sorted = indices.argsort(dim=1)
# move each flag to the column it refers to, then cast to float for the blend
mask = torch.gather(flags, 1, indices_sorted).float()
result = tensor * (1 - mask) + 1e-30 * mask
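
One caveat with the arithmetic blend: if tensor holds inf or nan at a masked position, multiplying by 0 yields nan instead of 1e-30. A possible variant, still assuming each row of indices is a permutation, keeps the mask boolean and selects with Tensor.masked_fill. Treat it as a sketch rather than the definitive approach.

```python
import torch

indices = torch.tensor([[2, 1, 3, 0], [1, 0, 3, 2]])
flags = torch.tensor([[False, False, False, True], [False, False, True, True]])
tensor = torch.tensor([[2.8, 0.5, 1.2, 0.9], [3.1, 2.8, 1.3, 2.5]])

# Inverse permutation per row, then move each flag to the column it refers to.
mask = torch.gather(flags, 1, indices.argsort(dim=1))  # stays boolean

# Overwrite masked entries by selection instead of multiplication.
result = tensor.masked_fill(mask, 1e-30)
```

masked_fill returns a new tensor, so tensor itself is left unchanged; the in-place form is masked_fill_.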

Upvotes: 2
