nmiculinic

Reputation: 2474

Weighted random tensor select in tensorflow

I have a list of tensors and a list representing their probability mass function. How can I tell TensorFlow, on each session run, to randomly pick one tensor according to that probability mass function?

I see a few possible ways to do that:

One is packing the list of tensors into a single tensor of rank one higher and selecting one with slice & squeeze, based on a TensorFlow variable to which I assign the correct index. What would the performance penalty be for this approach? Would TensorFlow evaluate the other, unneeded tensors?

Another is using tf.case in a similar fashion, picking one tensor out of many. Same question: what's the performance penalty, given that I plan on having quite a few (~100s) conditional statements per graph run?

Is there any better way of doing this?

Upvotes: 1

Views: 2196

Answers (1)

Olivier Moindrot

Reputation: 28218

I think you should use tf.multinomial(logits, num_samples).
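To make concrete what sampling from logits means, here is a minimal pure-Python sketch (no TensorFlow; the helper name `sample_from_logits` is hypothetical). It normalizes the logits with a softmax and then picks indices by inverse-CDF lookup. It illustrates the idea only and is not TensorFlow's actual implementation.

```python
import math
import random
from bisect import bisect_right
from itertools import accumulate


def sample_from_logits(logits, num_samples, rng=random):
    """Hypothetical helper: draw `num_samples` indices from the categorical
    distribution defined by unnormalized log-probabilities (logits)."""
    # Softmax, subtracting the max logit for numerical stability
    m = max(logits)
    weights = [math.exp(x - m) for x in logits]
    total = sum(weights)
    probs = [w / total for w in weights]

    # Inverse-CDF sampling: the first index whose cumulative mass exceeds u
    cdf = list(accumulate(probs))
    samples = []
    for _ in range(num_samples):
        u = rng.random()  # uniform in [0, 1)
        i = bisect_right(cdf, u)
        samples.append(min(i, len(cdf) - 1))  # guard against rounding at the top
    return samples


rng = random.Random(0)
draws = sample_from_logits([math.log(0.7), math.log(0.3)], num_samples=5, rng=rng)
```

Each call returns fresh indices, which mirrors getting a new draw on every session run.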

Say you have:

  • a batch of tensors of shape [batch_size, num_features]
  • a probability distribution of shape [batch_size]

You want to output:

  • 1 example from the batch of tensors, of shape [1, num_features]

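import tensorflow as tf  # TensorFlow 1.x API; later versions rename tf.multinomial to tf.random.categorical

batch_tensors = tf.constant([[0., 1., 2.], [3., 4., 5.]])  # shape [batch_size, num_features]
probabilities = tf.constant([0.7, 0.3])  # shape [batch_size]

# tf.multinomial expects unnormalized log-probabilities (logits) of shape [1, batch_size]
rescaled_probas = tf.expand_dims(tf.log(probabilities), 0)  # shape [1, batch_size]

# Draw one index from the distribution (increase num_samples to draw more)
indices = tf.multinomial(rescaled_probas, num_samples=1)  # shape [1, num_samples], int64

# Select the sampled row(s): shape [num_samples, num_features], here [1, num_features]
output = tf.gather(batch_tensors, tf.squeeze(indices, [0]))

with tf.Session() as sess:
    print(sess.run(output))  # each run performs a fresh random draw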
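Why tf.log(probabilities) is a valid input: tf.multinomial treats each row of logits as unnormalized log-probabilities, so the sampling probabilities are proportional to exp(logit), i.e. a softmax. A quick pure-Python check (no TensorFlow; `softmax` here is a hypothetical helper) that the softmax of log-probabilities gives back the original distribution, and that a constant shift or unnormalized weights change nothing:

```python
import math


def softmax(logits):
    """Hypothetical helper: normalize logits into probabilities (max-subtracted)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


probabilities = [0.7, 0.3]
logits = [math.log(p) for p in probabilities]  # elementwise, like tf.log

# Feeding log-probabilities as logits recovers the original distribution
recovered = softmax(logits)

# Adding the same constant to every logit leaves the distribution unchanged
shifted = softmax([x + 5.0 for x in logits])

# So the weights need not be normalized first: [7, 3] behaves like [0.7, 0.3]
unnormalized = softmax([math.log(w) for w in [7.0, 3.0]])
```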

Regarding "What's the performance penalty, since I plan on having quite a few (~100s) conditional statements per graph run?":

If you want multiple draws, do them all in one run by increasing the num_samples parameter. You can then gather those num_samples examples in the same run with tf.gather.
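The "draw all indices at once, then gather" pattern can be sketched in pure Python as a stand-in (no TensorFlow; `draw_and_gather` is a hypothetical helper built on the standard library's random.choices). In the TensorFlow snippet, the analogue is tf.multinomial(rescaled_probas, num_samples=k) followed by tf.gather(batch_tensors, tf.squeeze(indices, [0])), which yields a tensor of shape [k, num_features].

```python
import random


def draw_and_gather(batch, pmf, num_samples, rng=random):
    """Hypothetical stand-in for tf.multinomial + tf.gather: sample all
    indices in a single call, then select the matching rows in one pass."""
    indices = rng.choices(range(len(batch)), weights=pmf, k=num_samples)
    return indices, [batch[i] for i in indices]


batch_tensors = [[0., 1., 2.], [3., 4., 5.]]  # [batch_size, num_features]
indices, rows = draw_and_gather(batch_tensors, [0.7, 0.3], num_samples=4,
                                rng=random.Random(0))
```

Batching the draws this way replaces many per-draw selections with one sampling call and one gather.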

Upvotes: 6
