Reputation: 109
I'm trying to implement seq2seq with a feature to generate sampled outputs from the decoder, i.e. at every step, rather than taking the argmax of the output logits from the previous state, it should sample from them according to the logit distribution and use that as input for the next step.
After poking around, I found the loop_function in seq2seq.py as a promising place to start. It looks like I have to write a loop function like this (modified from the one in the file that does argmax + embedding):
def _extract_sample_and_embed(embedding, output_projection=None,
                              update_embedding=True):
    def loop_function(prev, _):
        if output_projection is not None:
            prev = nn_ops.xw_plus_b(prev, output_projection[0], output_projection[1])
        prev_symbol = math_ops.sample(prev)  # <------- Need this op, but it does not exist?
        emb_prev = embedding_ops.embedding_lookup(embedding, prev_symbol)
        if not update_embedding:
            emb_prev = array_ops.stop_gradient(emb_prev)
        return emb_prev
    return loop_function
Then I use this loop-function generator in the seq2seq_embedding_with_attention model. However, the op I need, one that samples from a tensor of floats, does not exist in TensorFlow, so do I need to write my own? How do I do that?
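For reference, the wiring I have in mind looks roughly like the sketch below: attention_decoder in seq2seq.py already accepts a loop_function argument, and decoder_inputs, initial_state, attention_states and cell stand in for tensors/objects built elsewhere in my graph (depending on your TF version, the seq2seq module path may differ):

    # Sketch only: decoder_inputs, initial_state, attention_states and cell
    # are placeholders for things already defined elsewhere in the graph.
    loop_fn = _extract_sample_and_embed(embedding, output_projection)
    outputs, state = tf.nn.seq2seq.attention_decoder(
        decoder_inputs, initial_state, attention_states, cell,
        loop_function=loop_fn)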
In searching for guidance, I found that in tensorflow/tensorflow/python/ops/candidate_sampling_ops there is a reference to:
from tensorflow.python.ops import gen_candidate_sampling_ops
but I can't find this file. I'm guessing it's auto-generated from somewhere. Where?
Upvotes: 4
Views: 1216
Reputation: 11
I ran into the same problem today, and my solution is:
Replace the line
prev_symbol = math_ops.sample(prev)
with
prev_symbol = tf.squeeze(tf.multinomial(prev, 1), axis=[1])
The function tf.multinomial() draws samples from a multinomial distribution. It takes a 2-D tensor logits of shape [batch_size, num_classes] and a 0-D scalar num_samples as input, and outputs drawn samples of shape [batch_size, num_samples].
Meanwhile, the loop function expects prev_symbol to have shape [batch_size], so we need tf.squeeze() to drop the extra dimension.
This implementation is simpler.
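For example, a quick standalone shape check (the logit values here are made up):

    import tensorflow as tf

    logits = tf.constant([[1.0, 2.0, 3.0],
                          [3.0, 2.0, 1.0]])      # [batch_size=2, num_classes=3]
    samples = tf.multinomial(logits, 1)          # shape [2, 1]
    prev_symbol = tf.squeeze(samples, axis=[1])  # shape [2]

    with tf.Session() as sess:
        print(sess.run(prev_symbol))  # e.g. [2 0]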
Upvotes: 1
Reputation: 937
Currently, you can also do it as follows, with the Gumbel-max trick for discrete distributions:
def batch_gumbel_max_sample(a, max_gumbel_noise=1.0):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1)
    matrix_U = -1.0 * tf.log(-1.0 * tf.log(tf.random_uniform(tf.shape(a),
                             minval=0.0, maxval=max_gumbel_noise)))
    # Gumbel-max trick: adding Gumbel noise to the logits and taking the
    # argmax per row samples according to the softmax distribution over a.
    return tf.argmax(tf.add(a, matrix_U), dimension=1)
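As a sanity check, one can draw repeatedly and compare the empirical frequencies against the softmax of the logits (a standalone sketch; the logit values are made up):

    # Assumes `import tensorflow as tf` and batch_gumbel_max_sample() from above.
    a = tf.constant([[0.0, 1.0, 2.0]])   # [batch_size=1, num_classes=3]
    sample = batch_gumbel_max_sample(a)  # shape [1]: one sampled index per row

    with tf.Session() as sess:
        counts = [0, 0, 0]
        for _ in range(1000):
            counts[sess.run(sample)[0]] += 1
        print(counts)  # roughly proportional to softmax([0, 1, 2]) ~ [0.09, 0.24, 0.67]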
There is also a discussion about this on TensorFlow's issue tracker at the moment. I guess sooner or later a multinomial sampling function will be added to TensorFlow. LeavesBreathe also posted a workaround on that GitHub page, which isn't entirely correct in my opinion (taking the argmax of the softmax probabilities minus independent uniform noise does not reproduce the multinomial distribution):
def batch_sample_with_temperature(a, temperature=1.0):
    '''This function is like sample_with_temperature, except it can handle
    batched input a of shape [batch_size x logits]. It takes logits as input
    and produces one sampled number per row, all on the GPU, because it is
    built from TensorFlow ops.

    As you increase the temperature, you get more diversified output, but
    with more errors (usually grammatical, if you're doing text).

    Args:
        a -- a 2-D array of logits, [batch_size x logits]
        temperature -- how much variance you want in the output

    Returns:
        The selected number per row, drawn from the distribution.

    The equation can be found at
    https://en.wikipedia.org/wiki/Softmax_function (under reinforcement
    learning). Karpathy did it here as well:
    https://github.com/karpathy/char-rnn/blob/4297a9bf69726823d944ad971555e91204f12ca8/sample.lua
    '''
    with tf.op_scope([a, temperature], "batch_sample_with_temperature"):
        # Start by reducing the temperature, and get rid of negative numbers
        # with the exponent.
        exponent_raised = tf.exp(tf.div(a, temperature))
        # keep_dims=True so the per-row sums broadcast against [batch_size x logits].
        matrix_X = tf.div(exponent_raised,
                          tf.reduce_sum(exponent_raised,
                                        reduction_indices=1,
                                        keep_dims=True))  # this yields probabilities!
        matrix_U = tf.random_uniform(tf.shape(a), minval=0, maxval=1)
        # dimension=1 because we argmax across rows.
        final_number = tf.argmax(tf.sub(matrix_X, matrix_U), dimension=1)
    return final_number
Upvotes: 4