Reputation: 495
For the past few days, I have been trying to figure out the flow of execution in the code https://github.com/tensorflow/models/blob/master/tutorials/embedding/word2vec.py#L28 .
I understand the logic behind negative sampling and the loss function, but I am getting confused about the flow of execution inside the train function, especially when it comes to the _train_thread_body
function. I am confused about the while loop and the if condition (what is their impact?) and the concurrency-related parts. It would be great if someone could give a decent explanation before down-voting this.
Upvotes: 1
Views: 107
Reputation: 53758
This sample code is called "Multi-threaded word2vec mini-batched skip-gram model", which is why it uses several independent threads for training. Word2Vec can be trained with a single thread as well, but this tutorial shows that word2vec is faster to compute when done in parallel.
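The threads are plain Python threads started in the model's train method. A simplified sketch of that method (progress reporting and summary writing omitted; see the file for the exact code) looks roughly like this:

import threading

def train(self):
    """Train the model (simplified: progress reporting omitted)."""
    opts = self._options

    # Launch opts.concurrent_steps workers; each one runs _train_thread_body.
    workers = []
    for _ in range(opts.concurrent_steps):
        t = threading.Thread(target=self._train_thread_body)
        t.start()
        workers.append(t)

    # Each worker exits once the epoch counter has advanced, so joining them
    # all means exactly one epoch of training has been performed.
    for t in workers:
        t.join()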
The input, label and epoch tensors are provided by the native word2vec.skipgram_word2vec function, which is implemented in the tutorials/embedding/word2vec_kernels.cc file. There you can see that current_epoch is a tensor updated once the whole corpus of sentences has been processed.
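In build_graph the op is loaded from the compiled word2vec_ops.so library and invoked roughly like this (condensed from the tutorial; argument names follow its flags):

word2vec = tf.load_op_library(os.path.join(
    os.path.dirname(os.path.realpath(__file__)), 'word2vec_ops.so'))

# One native op produces the vocabulary statistics, the epoch/word counters
# and the (examples, labels) stream that the training threads consume.
(words, counts, words_per_epoch, self._epoch, self._words, examples,
 labels) = word2vec.skipgram_word2vec(filename=opts.train_data,
                                      batch_size=opts.batch_size,
                                      window_size=opts.window_size,
                                      min_count=opts.min_count,
                                      subsample=opts.subsample)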
The method you're asking about is actually pretty simple:
def _train_thread_body(self):
    initial_epoch, = self._session.run([self._epoch])
    while True:
        _, epoch = self._session.run([self._train, self._epoch])
        if epoch != initial_epoch:
            break
First it reads the current epoch, then it keeps running the training op until the epoch is increased. This means that all of the threads running this method will perform exactly one epoch of training. Each thread does one step at a time, in parallel with the others.
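If the while/if part is still confusing, here is a small self-contained analogy in plain Python, with no TensorFlow involved; shared_epoch, train_step and steps_per_epoch are made-up stand-ins for the epoch tensor, the session.run call and the corpus size:

import threading

shared_epoch = 0          # stand-in for the current_epoch tensor
steps_done = 0            # stand-in for the words/steps processed so far
steps_per_epoch = 100     # stand-in for "the whole corpus has been seen"
lock = threading.Lock()

def train_step():
    """Stand-in for session.run([self._train, self._epoch])."""
    global shared_epoch, steps_done
    with lock:
        steps_done += 1
        if steps_done % steps_per_epoch == 0:
            shared_epoch += 1   # the kernel bumps the epoch after a full pass
        return shared_epoch

def train_thread_body():
    initial_epoch = shared_epoch    # the epoch this worker started in
    while True:
        epoch = train_step()        # one training step, which also reads the epoch
        if epoch != initial_epoch:
            break                   # this worker has finished its epoch

workers = [threading.Thread(target=train_thread_body) for _ in range(4)]
for t in workers:
    t.start()
for t in workers:
    t.join()
print("done, shared_epoch =", shared_epoch)

Each worker keeps taking steps until the shared epoch counter moves past the value it saw when it started, which is exactly what the if/break inside the while loop checks.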
self._train is an op that optimizes the loss function (see the optimize method); the loss is computed from the current examples and labels (see the build_graph method). The exact values of these tensors are determined in native code again, namely in NextExample.
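For reference, a slightly condensed version of the optimize method, which is where self._train is created:

def optimize(self, loss):
    """Build the graph to optimize the loss function (slightly condensed)."""
    opts = self._options
    # The learning rate decays linearly over the planned number of training words.
    words_to_train = float(opts.words_per_epoch * opts.epochs_to_train)
    lr = opts.learning_rate * tf.maximum(
        0.0001, 1.0 - tf.cast(self._words, tf.float32) / words_to_train)
    optimizer = tf.train.GradientDescentOptimizer(lr)
    # self._train is the op each worker thread keeps running in its loop.
    self._train = optimizer.minimize(loss,
                                     global_step=self.global_step,
                                     gate_gradients=optimizer.GATE_NONE)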
Essentially, each call of word2vec.skipgram_word2vec extracts the set of examples and labels that form the input to the optimization function. Hope this makes it clearer now.
By the way, this model uses NCE loss in training, not negative sampling.
Upvotes: 1