aghtaal

Reputation: 307

Recommended way to replace several values in a tensor at once?

Is there a batched way to replace several specific values in a PyTorch tensor at once, without a for loop?

Example:

import torch
old_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = torch.tensor([[2, 22], [3, 33], [6, 66]])

Each row of old_new_value is an (old, new) pair: 2 should be replaced by 22, 3 by 33, and 6 by 66.

Is there an efficient way to get the following end_result?

end_result = torch.tensor([1, 22, 33, 4, 5, 5, 22, 33, 33, 22])

Note that old_values contains duplicates. Also, old_new_value may contain a pair, here (6, 66), whose old value does not occur in old_values. The rows of old_new_value are unique.
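For reference, here is a minimal sketch of the plain loop the question wants to avoid. It loops only over the (small) list of pairs, not over the elements, and is mainly useful as ground truth for checking vectorized answers. The helper name `replace_values_loop` is hypothetical, and it assumes both inputs are tensors.

```python
import torch

def replace_values_loop(values: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
    # Hypothetical reference helper: one torch.where per (old, new) pair.
    # Masks are computed against the ORIGINAL values, so a new value that
    # equals another pair's old value is not replaced a second time.
    result = values.clone()
    for old, new in pairs.tolist():
        result = torch.where(values == old, torch.full_like(values, new), result)
    return result

old_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = torch.tensor([[2, 22], [3, 33], [6, 66]])
print(replace_values_loop(old_values, old_new_value).tolist())
# [1, 22, 33, 4, 5, 5, 22, 33, 33, 22]
```

The pair (6, 66) simply matches nothing, and the input tensor is left unchanged because the helper works on a clone.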

Upvotes: 4

Views: 7753

Answers (2)

Luis Herrmann

Reputation: 1

Not sure anyone still cares about this, but just in case, here is a solution that also works when old_values is not unique:

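import torch
# old_values and old_new_value must both be tensors; mask has shape (num_pairs, len(old_values))
mask = old_values == old_new_value[:, :1]
new_values = (1 - mask.sum(dim=0)) * old_values + (mask * old_new_value[:, 1:]).sum(dim=0)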

Masking works as in @kmario23's solution, but the mask is multiplied with the new values and sum-reduced over the pair dimension, which places each new value at its replacement positions (and 0 everywhere else). The complementary mask, 1 - mask.sum(dim=0), is applied to the old values to keep them at all other positions. Summing the two masked tensors gives the desired result.
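The two lines above can be spelled out step by step on the question's data. This is a sketch of the same arithmetic with named intermediates; the function name `replace_values_masked` is hypothetical, and it assumes the old values in the pairs are unique (so at most one pair matches each position).

```python
import torch

def replace_values_masked(values: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper wrapping the answer's two lines.
    # mask[i, j] is True where values[j] == pairs[i, 0]; shape (num_pairs, N)
    mask = values == pairs[:, :1]
    # 1 where no pair matched position j, 0 where one did
    keep_old = 1 - mask.sum(dim=0)
    # for each position: the new value of the matching pair, or 0 if none matched
    new_part = (mask * pairs[:, 1:]).sum(dim=0)
    return keep_old * values + new_part

old_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = torch.tensor([[2, 22], [3, 33], [6, 66]])
print(replace_values_masked(old_values, old_new_value).tolist())
# [1, 22, 33, 4, 5, 5, 22, 33, 33, 22]
```

The intermediate mask has num_pairs × N entries, so memory grows with both the number of pairs and the tensor size.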

Upvotes: 0

kmario23

Reputation: 61415

If your input tensor has no duplicate elements, here's one straightforward way using a boolean mask and value assignment through boolean indexing. This also assumes that every old value in old_new_value occurs in the input, in the same order as the rows of old_new_value, because the masked positions are filled with the new values in sequence. (I'll assume that the data type of the input tensor is int, but you can easily adapt this code to other dtypes.) Below is a reproducible illustration, with explanations in inline comments.

# input tensors to work with
In [75]: old_values    
Out[75]: tensor([1, 2, 3, 4, 5], dtype=torch.int32)

In [77]: old_new_value      
Out[77]:
tensor([[ 2, 22],
        [ 3, 33]], dtype=torch.int32)

# generate a boolean mask using the values that need to be replaced (i.e. 2 & 3)
In [78]: boolean_mask = (old_values == old_new_value[:, :1]).sum(dim=0).bool() 

In [79]: boolean_mask 
Out[79]: tensor([False,  True,  True, False, False])

# assign the new values by boolean-mask indexing
In [80]: old_values[boolean_mask] = old_new_value[:, 1:].squeeze() 

# sanity check!
In [81]: old_values 
Out[81]: tensor([ 1, 22, 33,  4,  5], dtype=torch.int32)
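To see why the assumptions matter, here is a sketch that wraps the same steps in a hypothetical helper `replace_by_boolean_mask` (operating on a copy) and runs it on three inputs: one that satisfies them, one whose old values appear in a different order than the pairs, and one with duplicates.

```python
import torch

def replace_by_boolean_mask(values: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: same steps as the session above, applied to a clone
    result = values.clone()
    boolean_mask = (values == pairs[:, :1]).sum(dim=0).bool()
    # masked positions are filled with the new values in order
    result[boolean_mask] = pairs[:, 1:].squeeze()
    return result

pairs = torch.tensor([[2, 22], [3, 33]])

# Works: each old value occurs once, in the same order as the rows of pairs
print(replace_by_boolean_mask(torch.tensor([1, 2, 3, 4, 5]), pairs).tolist())
# [1, 22, 33, 4, 5]

# Silently wrong: 3 comes before 2, so 3 receives 22 and 2 receives 33
print(replace_by_boolean_mask(torch.tensor([3, 2]), pairs).tolist())
# [22, 33]  (the intended result is [33, 22])

# Fails: duplicates give 6 masked positions but only 2 new values
try:
    replace_by_boolean_mask(torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2]), pairs)
except RuntimeError as err:
    print("error:", err)
```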

A small note on efficiency: the slices old_new_value[:, :1] and old_new_value[:, 1:] are views, so they don't copy any data. The comparison does allocate a new boolean tensor of shape (num_pairs, N), but every step is a vectorized tensor operation with no Python loop, so it should run fast unless num_pairs × N is very large.
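If the (num_pairs, N) mask becomes too large, a different technique (not used in either answer) is to sort the old values once and look every element up with `torch.searchsorted` (available since PyTorch 1.6), which needs only O(N) extra memory. This is a sketch under stated assumptions: the helper name is hypothetical, old values in the pairs are unique, there is at least one pair, and `values` and `pairs` share a dtype.

```python
import torch

def replace_values_searchsorted(values: torch.Tensor, pairs: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper; assumes unique old values and at least one pair.
    # Sort the (old, new) pairs by old value so searchsorted can be used
    old, order = pairs[:, 0].sort()
    new = pairs[order, 1]
    # Insertion index of each element into `old`; clamp so elements larger
    # than every key still yield a valid index
    idx = torch.searchsorted(old, values).clamp(max=old.numel() - 1)
    # Replace only where the key at that index actually equals the element
    found = old[idx] == values
    return torch.where(found, new[idx], values)

old_values = torch.tensor([1, 2, 3, 4, 5, 5, 2, 3, 3, 2])
old_new_value = torch.tensor([[2, 22], [3, 33], [6, 66]])
print(replace_values_searchsorted(old_values, old_new_value).tolist())
# [1, 22, 33, 4, 5, 5, 22, 33, 33, 22]
```

Unlike the boolean-mask assignment, this handles duplicates, pairs that never match, and any ordering of the pairs, at the cost of a sort over the pairs and a binary search per element.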

Upvotes: 4
