onthebox

Reputation: 15

Map each element of a torch.Tensor to its value in a dict

Suppose I have a tensor t consisting only of zeros and ones:

t = torch.Tensor([1, 0, 0, 1])

And a dict with the weights:

weights = {0: 0.1, 1: 0.9}

I want to form a new tensor new_t, such that every element in tensor t is mapped to the corresponding value in the dict weights:

new_t = torch.Tensor([0.9, 0.1, 0.1, 0.9])

Is there an elegant way to do this without iterating over tensor t? I've heard about torch.Tensor.apply_, but it only works if tensor t is on the CPU. Are there any other options?

Upvotes: 1

Views: 32

Answers (1)

Karl

Reputation: 5303

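If you convert your weights dict into a tensor, you can index it directly with t. This works because the dict keys are the integers 0, 1, ..., so position i of the weights tensor holds the weight for key i. The index tensor must have an integer dtype (torch.tensor([1, 0, 0, 1]) gives int64, whereas torch.Tensor([1, 0, 0, 1]) gives float32), and since this is ordinary tensor indexing it also runs on the GPU, as long as both tensors are on the same device: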

import torch

t = torch.tensor([1, 0, 0, 1])      # integer (int64) index tensor
weights = torch.tensor([0.1, 0.9])  # weights[i] is the weight for key i

new_t = weights[t]
print(new_t)  # tensor([0.9000, 0.1000, 0.1000, 0.9000])
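A slightly more general sketch, assuming the dict keys are non-negative integers (not necessarily contiguous) and that every value in t is one of those keys. The helper name map_with_dict is hypothetical. It builds the lookup tensor from the dict on t's device and casts t to long, so it also accepts the float tensor produced by torch.Tensor in the question.

```python
import torch


def map_with_dict(t, mapping):
    """Hypothetical helper: replace each element of t with mapping[element].

    Assumes every key in `mapping` is a non-negative integer and every
    value in `t` is one of those keys. No Python loop over t is used.
    """
    # Dense lookup table indexed by key; slots for absent keys stay 0.
    size = max(mapping) + 1
    table = torch.zeros(size, dtype=torch.float32, device=t.device)
    keys = torch.tensor(list(mapping.keys()), device=t.device)
    vals = torch.tensor(list(mapping.values()), dtype=torch.float32, device=t.device)
    table[keys] = vals
    # Indices must be integers; torch.Tensor([...]) creates float32 values.
    return table[t.long()]


t = torch.Tensor([1, 0, 0, 1])  # float32 tensor, as in the question
weights = {0: 0.1, 1: 0.9}
device = "cuda" if torch.cuda.is_available() else "cpu"
new_t = map_with_dict(t.to(device), weights)
print(new_t)  # on CPU: tensor([0.9000, 0.1000, 0.1000, 0.9000])
```

In this sketch, keys missing from the dict silently map to 0, so if t can contain such values you would need to check for them separately.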

Upvotes: 1
