DmiSH

Reputation: 105

How to change certain values in a torch tensor based on an index in another torch tensor?

This is an issue I'm running into while converting DQN to Double DQN for the CartPole problem. I'm getting close to figuring it out.

tensor([0.1205, 0.1207, 0.1197, 0.1195, 0.1204, 0.1205, 0.1208, 0.1199, 0.1206,
        0.1199, 0.1204, 0.1205, 0.1199, 0.1204, 0.1204, 0.1203, 0.1198, 0.1198,
        0.1205, 0.1204, 0.1201, 0.1205, 0.1208, 0.1202, 0.1205, 0.1203, 0.1204,
        0.1205, 0.1206, 0.1206, 0.1205, 0.1204, 0.1201, 0.1206, 0.1206, 0.1199,
        0.1198, 0.1200, 0.1206, 0.1207, 0.1208, 0.1202, 0.1201, 0.1210, 0.1208,
        0.1205, 0.1205, 0.1201, 0.1193, 0.1201, 0.1205, 0.1207, 0.1207, 0.1195,
        0.1210, 0.1204, 0.1209, 0.1207, 0.1187, 0.1202, 0.1198, 0.1202])
tensor([ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True, False,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True, False,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True,  True,  True,  True])

As you can see, there are two tensors. The first holds the Q-values I want, but some of them need to be set to zero because they correspond to an end state. The second tensor shows where the zeros should go.

Wherever the Boolean value is False, the element at the same index in the upper tensor needs to be zero. I am not sure how to do that.

Upvotes: 6

Views: 15654

Answers (2)

Dishin H Goyani

Reputation: 7693

You can use torch.where(condition, x, y), which returns a new tensor that takes elements from x where condition is True and from y where it is False.

Example:

>>> import torch
>>> from torch import tensor
>>> x = tensor([0.2853, 0.5010, 0.9933, 0.5880, 0.3915, 0.0141, 0.7745,
...             0.0588, 0.4939, 0.0849])
>>> condition = tensor([False,  True,  True,  True, False, False,  True,
...                     False, False, False])

>>> # Equivalent to torch.where(condition, x, tensor(0.0))
>>> x.where(condition, tensor(0.0))
tensor([0.0000, 0.5010, 0.9933, 0.5880, 0.0000, 0.0000, 0.7745, 0.0000, 0.0000,
        0.0000])
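A minimal self-contained sketch of the same idea in the question's setting, assuming the Q-values are a 1-D float tensor and the mask is a Boolean tensor of the same length. The names next_q_values and non_final_mask and the short four-element values are hypothetical stand-ins for the tensors in the question. Note that torch.where builds a new tensor and leaves the input unchanged.

```python
import torch

# Hypothetical stand-ins for the question's tensors: Q-values for a small batch
# and a Boolean mask that is False where the next state is terminal.
next_q_values = torch.tensor([0.1205, 0.1207, 0.1197, 0.1195])
non_final_mask = torch.tensor([True, False, True, False])

# Keep the Q-value where the mask is True, use 0.0 where it is False.
# torch.where returns a new tensor; next_q_values itself is not modified.
zeroed = torch.where(non_final_mask, next_q_values, torch.zeros_like(next_q_values))

print(zeroed)
```

Because the result is a fresh tensor, this form avoids in-place modification, which can matter if the Q-values are part of an autograd graph.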

Upvotes: 6

Anurag Reddy

Reputation: 1215

If the upper tensor is your value tensor and the lower one is your decision tensor, then

value_tensor[decision_tensor==False] = 0

You could also convert them to NumPy arrays and perform the same Boolean-mask assignment there.
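A sketch of the in-place masking approach, again with hypothetical small tensors (q_values and non_final_mask are illustrative names). In PyTorch, ~mask inverts a Boolean tensor and is equivalent to mask == False. Tensor.masked_fill_ is a built-in alternative I've swapped in here; it fills positions where the given mask is True. For the NumPy route, Tensor.numpy() on a CPU tensor shares memory with the tensor, so the sketch copies before modifying.

```python
import numpy as np
import torch

q_values = torch.tensor([0.1205, 0.1207, 0.1197, 0.1195])
non_final_mask = torch.tensor([True, False, True, False])

# 1) Boolean-mask assignment as in the answer (in place on the tensor it indexes).
masked = q_values.clone()   # clone so the original stays intact for comparison
masked[~non_final_mask] = 0  # ~ inverts the mask; same as non_final_mask == False

# 2) Built-in alternative: set 0.0 wherever the (inverted) mask is True, in place.
filled = q_values.clone().masked_fill_(~non_final_mask, 0.0)

# 3) NumPy equivalent. .numpy() shares memory with a CPU tensor, so copy first.
q_np = q_values.numpy().copy()
mask_np = non_final_mask.numpy()
q_np[~mask_np] = 0  # ~ on a NumPy bool array is element-wise logical NOT
```

If the tensor requires grad, .numpy() raises an error unless you call .detach() first.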

Upvotes: 4
