Isamu Isozaki

Reputation: 175

TensorFlow: How to use tf.reduce_prod only across selected indices in batched data

I'm quite sure there is an easy way to do this, but I have not been able to find it so far.

Problem

I have two tensors.

One holds a probability for each action across a batch, so it has shape N × M, where N is the batch size and M is the number of possible actions. It is called action_probs.

Because my agent can take multiple actions at once, my other tensor has ones for the actions that were chosen and zeros otherwise. It has the same shape as action_probs and is called action.

For each sample in the batch, I want to output the probability of choosing those actions. That probability is the product of the probabilities at the picked indices.

Failed attempts

I first tried to create a mask and then apply tf.reduce_prod to the masked result, as follows:

ones = tf.ones_like(action)
mask = tf.equal(action, ones)
action_probs_masked = tf.boolean_mask(action_probs, mask)
picked_action_probs = tf.reduce_prod(action_probs_masked, axis = 1)

However, tf.boolean_mask returns a flattened 1-D tensor rather than a 2-D one, so reducing along axis 1 fails. I then tried looping over each index with tf.while_loop, but the code became too complicated and buggy for me to continue.
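The shape problem described above can be reproduced outside TensorFlow. The following is a minimal sketch using NumPy boolean indexing as a stand-in for tf.boolean_mask; NumPy is an assumption here and not part of the original code. When the mask has the same shape as the array, the selected values come back as one flat array, so there is no axis 1 left to reduce over.

```python
import numpy as np

# Same example data as in the question
action = np.array([[0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0]])
action_probs = np.array([[0.9, 0.8, 0.4, 0.5], [0.5, 0.7, 0.6, 0.4]])

# Full-shape boolean mask, analogous to tf.equal(action, ones)
mask = action == 1.0

# Boolean indexing keeps only the selected values and drops the row structure
flat = action_probs[mask]

print(flat.shape)  # 1-D: the per-sample grouping is gone
```

Because the rows can have different numbers of picked actions (two in the first row, three in the second), the result cannot be a regular 2-D array, which is why the selection is flattened.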

Example

Input

action = [[0.0,0.0,1.0,1.0], [1.0,0.0,1.0,1.0]]
action_probs = [[0.9, 0.8, 0.4, 0.5], [0.5, 0.7, 0.6, 0.4]]

Output

output = [0.2, 0.12]

which is given by [0.4*0.5, 0.5*0.6*0.4]

If anything is unclear, please tell me in the comments!

Upvotes: 0

Views: 406

Answers (1)

giser_yugang

Reputation: 6176

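You can replace the entries of action_probs with 1 wherever action is 0. A factor of 1 leaves the product unchanged, so tf.reduce_prod along axis 1 then multiplies only the picked probabilities.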

import tensorflow as tf

action = tf.constant([[0.0,0.0,1.0,1.0], [1.0,0.0,1.0,1.0]],dtype=tf.float32)
action_probs  = tf.constant([[0.9, 0.8, 0.4, 0.5], [0.5, 0.7, 0.6, 0.4]],dtype=tf.float32)

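# Keep the probability where action == 1; use 1.0 elsewhere so it does not affect the product
action_probs_mask = tf.where(tf.equal(action, 1), action_probs, tf.ones_like(action_probs))
# Multiply along the action axis to get one probability per sample
result = tf.reduce_prod(action_probs_mask, axis=1)

# tf.Session is the TensorFlow 1.x API; under 2.x eager execution, result can be printed directly
with tf.Session() as sess:
    print(sess.run(result))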

[0.2  0.12]
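To check the masking idea without a TensorFlow session, here is a minimal NumPy sketch of the same computation. The function names picked_action_probs and picked_action_probs_pow are hypothetical and used only for illustration. The second function shows an alternative that relies on action containing only 0s and 1s: raising each probability to the power of its action entry gives p where action is 1 and 1 where action is 0.

```python
import numpy as np


def picked_action_probs(action_probs, action):
    """Product of the probabilities at positions where action == 1 (hypothetical helper)."""
    action_probs = np.asarray(action_probs, dtype=np.float64)
    action = np.asarray(action, dtype=np.float64)
    # Replace unpicked entries with 1.0 so they are neutral in the product
    masked = np.where(action == 1.0, action_probs, 1.0)
    return masked.prod(axis=1)


def picked_action_probs_pow(action_probs, action):
    """Alternative sketch: p ** 1 == p and p ** 0 == 1, assuming action is strictly 0/1."""
    action_probs = np.asarray(action_probs, dtype=np.float64)
    action = np.asarray(action, dtype=np.float64)
    return np.prod(np.power(action_probs, action), axis=1)


action = [[0.0, 0.0, 1.0, 1.0], [1.0, 0.0, 1.0, 1.0]]
action_probs = [[0.9, 0.8, 0.4, 0.5], [0.5, 0.7, 0.6, 0.4]]

print(picked_action_probs(action_probs, action))
print(picked_action_probs_pow(action_probs, action))
```

Both functions should agree with the expected output from the question, [0.4*0.5, 0.5*0.6*0.4], up to floating-point rounding. The power form would presumably carry over to TensorFlow as tf.reduce_prod(tf.pow(action_probs, action), axis=1), though the tf.where version states the intent more directly.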

Upvotes: 1
