true equals false

Reputation: 828

Is there a function for picking the max value in a PyTorch tensor that satisfies a requirement?

I am programming an Othello bot in Python using reinforcement learning and PyTorch. In the program, I scan the board for legal moves. The AI should choose the move with the highest probability of being good, restricted to the moves that are legal according to the previous calculation. Here I need a function that works something like this:

a = torch.tensor([1,2,3,4,5])
b = torch.tensor([True, True, False, True, False], dtype=bool)
print(torch.somefunction(a,b))

The output should be the index of the max value in a among the positions where b is True, in this case 3. Does this function exist? And if not, is there any other way of doing this?

Upvotes: 1

Views: 1189

Answers (1)

Ivan

Reputation: 40778

Assuming at least one of the allowed values in your tensor is positive, you can multiply the tensor by the mask, which zeroes out the excluded values before taking the argmax:

>>> torch.argmax(a*b)
tensor(3)

If that's not the case, you can still get away with it by using torch.where to replace the excluded values with a value that will not win the argmax (e.g. a.min()):

>>> torch.where(b, a, a.min()).argmax()
tensor(3)
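
One edge case in both snippets above: an excluded position can tie with the best allowed value (for instance when the best allowed value is 0 after multiplication, or equals a.min()), and argmax may then return an excluded index. A minimal sketch that avoids this, assuming the scores can be cast to float, fills the excluded positions with -inf using masked_fill. The helper name masked_argmax is hypothetical, not a PyTorch function.

```python
import torch

def masked_argmax(scores: torch.Tensor, legal: torch.Tensor) -> torch.Tensor:
    """Index of the largest entry of `scores` among positions where `legal` is True.

    Hypothetical helper for illustration; assumes `legal` is a boolean tensor
    with the same shape as `scores`.
    """
    # Cast to float so that -inf is representable, then overwrite the
    # excluded positions with -inf so they cannot win the argmax.
    masked = scores.float().masked_fill(~legal, float("-inf"))
    return masked.argmax()

a = torch.tensor([1, 2, 3, 4, 5])
b = torch.tensor([True, True, False, True, False])
idx = masked_argmax(a, b)

# Also behaves sensibly when every allowed value is negative.
neg = torch.tensor([-5.0, -1.0, 0.0, -3.0])
neg_mask = torch.tensor([True, True, False, True])
neg_idx = masked_argmax(neg, neg_mask)
```

If b contains no True entries at all, every position is -inf and the returned index is meaningless, so it is worth checking b.any() before calling this.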

Upvotes: 1
