Sarath R Nair

Reputation: 495

Flow of execution in word2vec tensorflow

For the past few days, I have been trying to figure out the flow of execution in the code at https://github.com/tensorflow/models/blob/master/tutorials/embedding/word2vec.py#L28 .

I understand the logic behind negative sampling and the loss function, but I am getting confused about the flow of execution inside the train function, especially the _train_thread_body function. I am confused about the while loop and the if condition (what their impact is) and about the concurrency-related parts. It would be great if someone could give a decent explanation before down-voting this.

Upvotes: 1

Views: 107

Answers (1)

Maxim

Reputation: 53758

This sample code is called "Multi-threaded word2vec mini-batched skip-gram model", which is why it uses several independent threads for training. Word2Vec can be trained with a single thread as well, but this tutorial shows that word2vec computes faster when done in parallel.

The input, label and epoch tensors are provided by the native word2vec.skipgram_word2vec function, which is implemented in the tutorials/embedding/word2vec_kernels.cc file. There you can see that current_epoch is a tensor that is updated once the whole corpus of sentences has been processed.
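To make that concrete, here is a rough sketch of how the op is wired up in build_graph. The output order and option names below are quoted from memory of the linked word2vec.py, so treat them as an approximation and check the file for the exact signature:

# Sketch only -- output order and option names are assumptions based on
# the linked tutorial; see word2vec.py / word2vec_kernels.cc for details.
(words, counts, words_per_epoch, current_epoch, total_words_processed,
 examples, labels) = word2vec.skipgram_word2vec(
    filename=opts.train_data,      # path to the training corpus
    batch_size=opts.batch_size,
    window_size=opts.window_size,
    min_count=opts.min_count,
    subsample=opts.subsample)
self._epoch = current_epoch        # advanced by the kernel after each full pass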

The method you're asking about is actually pretty simple:

def _train_thread_body(self):
  initial_epoch, = self._session.run([self._epoch])  # epoch this thread starts in
  while True:
    # Run one training step and fetch the (possibly advanced) epoch counter.
    _, epoch = self._session.run([self._train, self._epoch])
    if epoch != initial_epoch:
      break  # the epoch counter moved on, so this thread has finished one full epoch

First it reads the current epoch, then it keeps invoking the training step until the epoch counter increases. This means that every thread running this method performs exactly one epoch of training, each thread doing one step at a time in parallel with the others.
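For context, the workers are plain Python threads started by the train method. Here is a simplified sketch; the option name concurrent_steps and the surrounding details are from memory of the linked file, so this is an illustration rather than an exact quote:

import threading

def train(self):
  opts = self._options
  # Start several workers; each one runs _train_thread_body once,
  # i.e. trains for exactly one epoch.
  workers = []
  for _ in range(opts.concurrent_steps):
    t = threading.Thread(target=self._train_thread_body)
    t.start()
    workers.append(t)
  # Block until every worker has seen the corpus once.
  for t in workers:
    t.join()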

self._train is an op that optimizes the loss function (see the optimize method), which is computed from the current examples and labels (see the build_graph method). The exact values of these tensors come from native code again, namely from NextExample. Essentially, each evaluation of the word2vec.skipgram_word2vec op extracts a new set of examples and labels, which forms the input to the optimization function. Hope it makes it clearer now.

By the way, this model uses NCE loss in training, not negative sampling.
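The tutorial builds the NCE loss by hand (see its nce_loss method), but the idea is the same as TensorFlow's built-in tf.nn.nce_loss. Here is a minimal, self-contained sketch in TF 1.x style; all sizes and variable names are placeholders made up for illustration:

import tensorflow as tf  # TF 1.x API, matching the tutorial

vocab_size, emb_dim, batch_size, num_sampled = 10000, 128, 64, 25

train_inputs = tf.placeholder(tf.int32, shape=[batch_size])     # center words
train_labels = tf.placeholder(tf.int32, shape=[batch_size, 1])  # context words

embeddings = tf.Variable(tf.random_uniform([vocab_size, emb_dim], -1.0, 1.0))
nce_weights = tf.Variable(tf.truncated_normal([vocab_size, emb_dim], stddev=0.05))
nce_biases = tf.Variable(tf.zeros([vocab_size]))

input_emb = tf.nn.embedding_lookup(embeddings, train_inputs)

# NCE loss: learn to tell the true context word apart from num_sampled noise words.
loss = tf.reduce_mean(
    tf.nn.nce_loss(weights=nce_weights,
                   biases=nce_biases,
                   labels=train_labels,
                   inputs=input_emb,
                   num_sampled=num_sampled,
                   num_classes=vocab_size))
train_op = tf.train.GradientDescentOptimizer(0.2).minimize(loss)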

Upvotes: 1
