Reputation: 105
This is an issue I'm running into while converting DQN to Double DQN for the CartPole problem. I'm getting close to figuring it out.
tensor([0.1205, 0.1207, 0.1197, 0.1195, 0.1204, 0.1205, 0.1208, 0.1199, 0.1206,
0.1199, 0.1204, 0.1205, 0.1199, 0.1204, 0.1204, 0.1203, 0.1198, 0.1198,
0.1205, 0.1204, 0.1201, 0.1205, 0.1208, 0.1202, 0.1205, 0.1203, 0.1204,
0.1205, 0.1206, 0.1206, 0.1205, 0.1204, 0.1201, 0.1206, 0.1206, 0.1199,
0.1198, 0.1200, 0.1206, 0.1207, 0.1208, 0.1202, 0.1201, 0.1210, 0.1208,
0.1205, 0.1205, 0.1201, 0.1193, 0.1201, 0.1205, 0.1207, 0.1207, 0.1195,
0.1210, 0.1204, 0.1209, 0.1207, 0.1187, 0.1202, 0.1198, 0.1202])
tensor([ True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, False, True, True, True,
True, True, True, True, True, True, True, False, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True, True, True, True, True, True, True,
True, True, True, True])
As you can see, there are two tensors here.
The first holds the Q-values I want, but some values need to be set to zero because they correspond to an end (terminal) state.
The second tensor shows which positions should be zeroed: wherever the Boolean value is False, the value at the same index in the first tensor needs to be zero. I am not sure how to do that.
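One detail in the printout above is worth checking: the Q-value tensor has 62 entries, while the Boolean tensor has 64 entries with exactly two False values (64 - 2 = 62). If that is not an artifact of how the output was copied, the Q-values were probably computed only for the non-terminal next states, so the two tensors cannot be masked element by element. In that case a common approach, similar to the non_final_mask pattern in PyTorch's DQN tutorial, is to start from a zero tensor of the full batch size and fill the True positions with the non-terminal Q-values. The sketch below uses random placeholder values; only the sizes (64 and 62) and the two False positions (indices 16 and 27) come from the printout.

```python
import torch

# Hypothetical stand-ins matching the printed sizes: a batch of 64 transitions
# whose mask is False at indices 16 and 27 (the two terminal next states).
batch_size = 64
non_final_mask = torch.ones(batch_size, dtype=torch.bool)
non_final_mask[[16, 27]] = False

# Q-values computed only for the 62 non-terminal next states (random placeholders).
q_non_final = torch.rand(int(non_final_mask.sum()))

# Start from zeros, then fill the True positions in order; terminal slots stay 0.
next_state_values = torch.zeros(batch_size)
next_state_values[non_final_mask] = q_non_final

print(next_state_values.shape)  # torch.Size([64])
```

Applying torch.where or boolean-mask indexing directly to a 62-element tensor and a 64-element mask would instead raise a shape-mismatch error.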
Upvotes: 6
Views: 15654
Reputation: 7693
You can use torch.where(condition, x, y), which takes the elements of x where condition is True and the elements of y everywhere else.
For example:
>>> import torch
>>> from torch import tensor
>>> x = tensor([0.2853, 0.5010, 0.9933, 0.5880, 0.3915, 0.0141, 0.7745,
...             0.0588, 0.4939, 0.0849])
>>> condition = tensor([False, True, True, True, False, False, True,
...                     False, False, False])
>>> # It's equivalent to `torch.where(condition, x, tensor(0.0))`
>>> x.where(condition, tensor(0.0))
tensor([0.0000, 0.5010, 0.9933, 0.5880, 0.0000, 0.0000, 0.7745,
        0.0000, 0.0000, 0.0000])
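Applied to the question's setup, a minimal sketch looks like the following, assuming the Q-value tensor and the Boolean mask have the same shape (True = keep, False = terminal). The four values and the name not_done are hypothetical. The multiply-by-mask line is an equivalent alternative, not part of the answer above.

```python
import torch

# Hypothetical Q-values and "not done" mask (True = keep, False = terminal).
q_values = torch.tensor([0.1205, 0.1207, 0.1197, 0.1195])
not_done = torch.tensor([True, False, True, False])

# torch.where keeps q_values where the mask is True and takes 0.0 elsewhere.
zeroed = torch.where(not_done, q_values, torch.zeros_like(q_values))

# Alternative: multiply by the mask cast to float (False -> 0.0, True -> 1.0).
zeroed_mul = q_values * not_done.float()

print(zeroed)  # tensor([0.1205, 0.0000, 0.1197, 0.0000])
```

torch.where returns a new tensor and leaves q_values unchanged, which is convenient when the Q-values are part of a computation graph.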
Upvotes: 6
Reputation: 1215
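If the top tensor is your value tensor and the bottom one is the decision (mask) tensor, then:
value_tensor[decision_tensor==False] = 0
For a Boolean mask this is the same as value_tensor[~decision_tensor] = 0. Note that it modifies value_tensor in place.
The same Boolean-mask indexing also works on NumPy arrays if you convert the tensors with .numpy(), although converting is not necessary.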
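A self-contained sketch of the indexing approach, in both PyTorch and NumPy, follows. The data and names are hypothetical, and the clone()/copy() calls are an added precaution so the original tensor is not overwritten.

```python
import numpy as np
import torch

# Hypothetical Q-values and mask (True = keep, False = terminal).
q_values = torch.tensor([0.1205, 0.1207, 0.1197, 0.1195])
not_done = torch.tensor([True, False, True, False])

# In-place Boolean-mask assignment on a copy; ~not_done selects the False positions.
masked = q_values.clone()
masked[~not_done] = 0

# Same operation in NumPy. .numpy() needs a CPU tensor that does not require grad
# (call .detach() first otherwise), and it shares memory with the tensor, so copy()
# keeps q_values unchanged.
q_np = q_values.numpy().copy()
q_np[~not_done.numpy()] = 0
```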
Upvotes: 4