Burrito

Reputation: 1634

How does Spark Word2Vec merge each partition's results?

Increasing numPartitions for Spark's Word2Vec makes training faster but less accurate, since each partition is fit separately (which reduces the context available for each word) and the per-partition results are then merged.

How exactly does it merge the results from multiple partitions? Is it just an average of the vectors? Looking to better understand how this affects the accuracy.
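For reference, the knob I'm talking about is setNumPartitions on the RDD-based spark.mllib Word2Vec. A minimal sketch of the setup (sc and the data path are just placeholders):

import org.apache.spark.mllib.feature.Word2Vec
import org.apache.spark.rdd.RDD

// Tokenized corpus; sc is an existing SparkContext and the path is a placeholder
val corpus: RDD[Seq[String]] = sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)

val model = new Word2Vec()
  .setVectorSize(100)
  .setMinCount(5)
  .setNumPartitions(8) // more partitions: faster training, but each partition is fit separately
  .fit(corpus)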

Looking at the source code, I think the merging is happening here:

val synAgg = partial.reduceByKey { case (v1, v2) =>
  // saxpy with alpha = 1.0f is an in-place element-wise sum: v1 += v2
  blas.saxpy(vectorSize, 1.0f, v2, 1, v1, 1)
  v1
}.collect()

That looks like just an element-wise vector sum (effectively an average across partitions; see the toy sketch at the end of the question). partial comes from:

val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter =>
  // Each sentence will map to 0 or more Array[Int]
  sentenceIter.flatMap { sentence =>
    // Sentence of words, some of which map to a word index
    val wordIndexes = sentence.flatMap(bcVocabHash.value.get)
    // break wordIndexes into chunks of maxSentenceLength when longer
    wordIndexes.grouped(maxSentenceLength).map(_.toArray)
  }
}

val newSentences = sentences.repartition(numPartitions).cache()

val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
  // ... long calculation (skip-gram training, etc.)
}
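
As a side note, the grouped(maxSentenceLength) call above just splits long sentences of word indexes into fixed-size chunks, e.g. with toy values:

val maxSentenceLength = 3
val wordIndexes = Seq(7, 2, 9, 4, 1, 8, 5)
wordIndexes.grouped(maxSentenceLength).map(_.toArray).toList
// List(Array(7, 2, 9), Array(4, 1, 8), Array(5))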

But I'm not a Word2Vec/Spark ML/Scala expert, so I'm hoping someone more knowledgeable can verify.
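
For concreteness, here is a tiny Spark-free toy of what I think that reduceByKey amounts to (made-up word indexes and vectors, plain arrays in place of BLAS):

// (wordIndex, vector) pairs as they might come out of two partitions (toy data)
val partial = Seq(
  (0, Array(1.0f, 2.0f)), // word 0, trained in partition A
  (0, Array(3.0f, 4.0f)), // word 0, trained in partition B
  (1, Array(5.0f, 6.0f))  // word 1, only seen in partition A
)

// the same element-wise sum that saxpy(vectorSize, 1.0f, v2, 1, v1, 1) does in place
val synAgg = partial
  .groupBy(_._1)
  .map { case (id, vecs) =>
    id -> vecs.map(_._2).reduce((v1, v2) => v1.zip(v2).map { case (a, b) => a + b })
  }
// synAgg(0): Array(4.0, 6.0), synAgg(1): Array(5.0, 6.0)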

Upvotes: 2

Views: 189

Answers (1)

philipmathieu

Reputation: 13

saxpy is a routine from BLAS (a widely used linear algebra library) that computes "a scalar times a vector plus a vector", i.e. y := alpha*x + y. In this case the scalar alpha is 1.0f, so the call simply sums the two vectors in place. The increment arguments (the 1s) tell the routine that the elements are contiguous in memory (stride of 1), which lets it compute the result more efficiently. In more recent versions of Spark, an additional normalization term is used to prevent overflows (see here).
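To illustrate the semantics, here is a plain-Scala sketch of what saxpy does for contiguous arrays (increments of 1); this is just an illustration, not the actual BLAS implementation:

// y := alpha * x + y, updating y in place (increments of 1 assumed)
def saxpy(n: Int, alpha: Float, x: Array[Float], y: Array[Float]): Unit = {
  var i = 0
  while (i < n) {
    y(i) += alpha * x(i)
    i += 1
  }
}

val v1 = Array(1.0f, 2.0f, 3.0f)
val v2 = Array(10.0f, 20.0f, 30.0f)
saxpy(3, 1.0f, v2, v1) // alpha = 1.0f, so this is just v1 += v2 element-wise
// v1 is now Array(11.0f, 22.0f, 33.0f)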

You are correct that this is effectively an average of the vectors. If you think of Word2Vec as a neural network, this is roughly like averaging updates over one very large batch, with the effective batch size being the number of rows in each partition of data. Since that is a very large number, it could prevent you from reaching the absolute optimal result (that is, the set of "perfect" embeddings that minimize the Word2Vec loss function), but this may or may not be a real issue depending on your application and dataset.
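
As a rough sketch of the averaging idea (an illustration only, not Spark's actual code; the counts here are a stand-in for how many partitions contributed to each word):

// summed per-word vectors after the reduce, plus made-up contribution counts
val summed = Map(0 -> Array(4.0f, 6.0f), 1 -> Array(5.0f, 6.0f))
val counts = Map(0 -> 2, 1 -> 1)

// scale each vector by 1 / count to turn the sum into an average
val averaged = summed.map { case (id, vec) =>
  id -> vec.map(_ / counts(id))
}
// averaged(0): Array(2.0, 3.0), averaged(1): Array(5.0, 6.0)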

Upvotes: 0
