Reputation: 828
I am programming an Othello bot in Python using reinforcement learning and PyTorch. In the program, I scan the board for legal moves. The AI should choose the move that has the highest probability of being good and that is legal according to the previous calculation. Here I need a function that works something like this:
a = torch.tensor([1,2,3,4,5])
b = torch.tensor([True, True, False, True, False], dtype=bool)
print(torch.somefunction(a,b))
The output should be the index of the max value in a among the positions where b is True, in this case 3. Does such a function exist? If not, is there another way of doing this?
Upvotes: 1
Views: 1189
Reputation: 40778
Assuming at least one of the allowed values is positive, you can multiply the tensor by the mask itself, which zeroes out the excluded values before the argmax:
>>> torch.argmax(a*b)
tensor(3)
If that's not the case, you can still get away with it using torch.where, by replacing the excluded values with some value that will get ignored by the argmax (e.g. a.min()):
>>> torch.where(b, a, a.min()).argmax()
tensor(3)
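One caveat with a.min() as the fill value: if the only allowed entries are themselves equal to a.min(), the argmax sees a tie and may return an excluded index. A minimal sketch of a stricter variant follows, assuming you are fine with a small helper (the name masked_argmax is made up for illustration, it is not a PyTorch function). It fills excluded positions with the lowest value the dtype can represent, using masked_fill:

```python
import torch

def masked_argmax(values: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Index of the largest entry of `values` among positions where `mask` is True.

    Hypothetical helper for illustration; not part of the PyTorch API.
    """
    if not mask.any():
        raise ValueError("mask excludes every position")
    # Pick a fill value that cannot beat any ordinary allowed entry.
    if values.is_floating_point():
        fill = float("-inf")
    else:
        fill = torch.iinfo(values.dtype).min
    # Overwrite the excluded positions (where mask is False), then take the argmax.
    return values.masked_fill(~mask, fill).argmax()

a = torch.tensor([1, 2, 3, 4, 5])
b = torch.tensor([True, True, False, True, False])
print(masked_argmax(a, b))  # tensor(3)
```

For the Othello case this would be called with the move scores as values and the legal-move mask as mask. The explicit check for an all-False mask is a design choice for the sketch; you may prefer to handle "no legal move" before reaching this point.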
Upvotes: 1