niloofar zarif

Reputation: 19

How embedding_bag exactly works in PyTorch

In PyTorch, torch.nn.functional.embedding_bag seems to be the main function responsible for doing the real job of embedding lookup. Its documentation mentions that embedding_bag does its job

> without instantiating the intermediate embeddings.

What does that mean exactly? Does it mean, for example, that when the mode is "sum" it does an in-place summation? Or does it just mean that no additional Tensors are produced when calling embedding_bag, but that from the system's point of view all the intermediate row-vectors are still fetched into the processor to be used for calculating the final Tensor?

Upvotes: 1

Views: 5733

Answers (1)

Kurt Mohler

Reputation: 451

In the simplest case, torch.nn.functional.embedding_bag is conceptually a two-step process. The first step is to create an embedding and the second step is to reduce (sum/mean/max, according to the mode argument) the embedding output across dimension 0. So you can get the same result embedding_bag gives by calling torch.nn.functional.embedding, followed by torch.sum/mean/max. In the following example, embedding_bag_res and embedding_mean_res are equal.

>>> weight = torch.randn(3, 4)
>>> weight
tensor([[ 0.3987,  1.6173,  0.4912,  1.5001],
        [ 0.2418,  1.5810, -1.3191,  0.0081],
        [ 0.0931,  0.4102,  0.3003,  0.2288]])
>>> indices = torch.tensor([2, 1])
>>> embedding_res = torch.nn.functional.embedding(indices, weight)
>>> embedding_res
tensor([[ 0.0931,  0.4102,  0.3003,  0.2288],
        [ 0.2418,  1.5810, -1.3191,  0.0081]])
>>> embedding_mean_res = embedding_res.mean(dim=0, keepdim=True)
>>> embedding_mean_res
tensor([[ 0.1674,  0.9956, -0.5094,  0.1185]])
>>> embedding_bag_res = torch.nn.functional.embedding_bag(indices, weight, torch.tensor([0]), mode='mean')
>>> embedding_bag_res
tensor([[ 0.1674,  0.9956, -0.5094,  0.1185]])

However, the conceptual two step process does not reflect how it's actually implemented. Since embedding_bag does not need to return the intermediate result, it doesn't actually generate a Tensor object for the embedding. It just goes straight to computing the reduction, pulling in the appropriate data from the weight argument according to the indices in the input argument. Avoiding the creation of the embedding Tensor allows for better performance.
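To make that concrete, here is a rough Python-level sketch of the fused idea, continuing the session above. This is illustrative only; the actual implementation is a compiled kernel, not a Python loop. The point is that each row of weight is read and folded into the result immediately, so the intermediate (len(indices), 4) embedding tensor is never materialized:

>>> out = torch.zeros(1, weight.shape[1])
>>> for i in indices:
...     out += weight[i]  # fetch one row at a time, reduce immediately
...
>>> out /= len(indices)  # divide by the bag size for 'mean' mode
>>> out
tensor([[ 0.1674,  0.9956, -0.5094,  0.1185]])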

So the answer to your question (if I understand it correctly)

> it just means that no additional Tensors will be produced when calling embedding_bag but still from the system's point of view all the intermediate row-vectors are already fetched into the processor to be used for calculating the final Tensor?

is yes.
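For completeness: the example above uses a single bag (offsets=torch.tensor([0])). In general, offsets marks where each bag starts in the flat indices tensor, and each bag is reduced independently. A sketch continuing the same session (outputs shown assume the same weight as above), where bag 0 covers indices[0:2] and bag 1 covers indices[2:]:

>>> indices = torch.tensor([2, 1, 0])
>>> offsets = torch.tensor([0, 2])  # bag 0 starts at 0, bag 1 starts at 2
>>> torch.nn.functional.embedding_bag(indices, weight, offsets, mode='mean')
tensor([[ 0.1674,  0.9956, -0.5094,  0.1185],
        [ 0.3987,  1.6173,  0.4912,  1.5001]])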

Upvotes: 6
