Reputation: 19
In PyTorch, torch.nn.functional.embedding_bag seems to be the main function responsible for doing the real work of embedding lookup. Its documentation mentions that embedding_bag does its job

> without instantiating the intermediate embeddings.

What does that mean exactly? Does it mean, for example, that when the mode is "sum" it does an in-place summation? Or does it just mean that no additional Tensors are produced when calling embedding_bag, but from the system's point of view all the intermediate row vectors are still fetched into the processor to be used for calculating the final Tensor?
Upvotes: 1
Views: 5733
Reputation: 451
In the simplest case, torch.nn.functional.embedding_bag is conceptually a two-step process. The first step is to create an embedding, and the second step is to reduce (sum/mean/max, according to the "mode" argument) the embedding output across dimension 0. So you can get the same result that embedding_bag gives by calling torch.nn.functional.embedding, followed by torch.sum/mean/max. In the following example, embedding_bag_res and embedding_mean_res are equal.
>>> weight = torch.randn(3, 4)
>>> weight
tensor([[ 0.3987, 1.6173, 0.4912, 1.5001],
[ 0.2418, 1.5810, -1.3191, 0.0081],
[ 0.0931, 0.4102, 0.3003, 0.2288]])
>>> indices = torch.tensor([2, 1])
>>> embedding_res = torch.nn.functional.embedding(indices, weight)
>>> embedding_res
tensor([[ 0.0931, 0.4102, 0.3003, 0.2288],
[ 0.2418, 1.5810, -1.3191, 0.0081]])
>>> embedding_mean_res = embedding_res.mean(dim=0, keepdim=True)
>>> embedding_mean_res
tensor([[ 0.1674, 0.9956, -0.5094, 0.1185]])
>>> embedding_bag_res = torch.nn.functional.embedding_bag(indices, weight, torch.tensor([0]), mode='mean')
>>> embedding_bag_res
tensor([[ 0.1674, 0.9956, -0.5094, 0.1185]])
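The same equivalence holds with more than one bag. As a rough sketch (variable names here are just for illustration), the offsets argument marks where each bag starts in a flat index tensor, and each bag is reduced independently:

import torch
import torch.nn.functional as F

weight = torch.randn(3, 4)

# One flat index tensor holding two bags: indices[0:2] and indices[2:4].
indices = torch.tensor([2, 1, 0, 2])
offsets = torch.tensor([0, 2])  # each entry is the start of a bag

bag_sum = F.embedding_bag(indices, weight, offsets, mode='sum')

# The equivalent two-step computation: look up, then reduce each bag.
manual = torch.stack([weight[indices[0:2]].sum(dim=0),
                      weight[indices[2:4]].sum(dim=0)])
assert torch.allclose(bag_sum, manual)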
However, the conceptual two-step process does not reflect how it's actually implemented. Since embedding_bag does not need to return the intermediate result, it doesn't actually generate a Tensor object for the embedding. It goes straight to computing the reduction, pulling in the appropriate data from the weight argument according to the indices in the input argument. Avoiding the creation of the embedding Tensor allows for better performance.
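Conceptually, the fused reduction behaves something like the following pure-Python sketch. This is only an illustration of the idea, not PyTorch's actual C++/CUDA implementation, which is vectorized and parallelized:

import torch

def fused_bag_sum(indices, weight):
    # Accumulate rows of weight directly into the output buffer;
    # no (len(indices), embedding_dim) intermediate tensor is created.
    out = torch.zeros(weight.shape[1])
    for i in indices:
        out += weight[i]  # fetch one row and fold it into the running sum
    return out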
So the answer to your question (if I understand it correctly):

> it just means that no additional Tensors will be produced when calling embedding_bag but still from the system's point of view all the intermediate row-vectors are already fetched into the processor to be used for calculating the final Tensor?

is yes.
Upvotes: 6