dingus

Reputation: 1001

How to invert a PyTorch Embedding?

I have a multi-task encoder/decoder model in PyTorch with a (trainable) torch.nn.Embedding layer at the input.

In one particular task, I'd like to pre-train the model self-supervised (to re-construct masked input data) and use it for inference (to fill in gaps in data).

I guess at training time I can just measure loss as the distance between the input embedding and the output embedding... But at inference time, how do I invert the Embedding to recover the category/token the output corresponds to? I can't see e.g. a "nearest" method on the Embedding class...

Upvotes: 9

Views: 5950

Answers (1)

Szymon Maszke

Reputation: 24691

You can do it quite easily:

import torch

# 1000 tokens, each embedded in 100 dimensions
embeddings = torch.nn.Embedding(1000, 100)
my_sample = torch.randn(1, 100)

# Euclidean distance from the sample to every embedding row
distance = torch.norm(embeddings.weight.data - my_sample, dim=1)
nearest = torch.argmin(distance)  # index of the closest token

Assuming you have 1000 tokens, each embedded in 100 dimensions, this returns the index of the nearest embedding by Euclidean distance. You could use other metrics in a similar manner.
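For completeness, here is a sketch of the same idea extended to a batch of decoder outputs, along with cosine similarity as an alternative metric (the batch size of 4 and the use of `torch.cdist` are illustrative choices, not something from the question):

```python
import torch

embeddings = torch.nn.Embedding(1000, 100)

# Hypothetical batch of 4 decoder output vectors
outputs = torch.randn(4, 100)

# Euclidean: pairwise distances between each output and every embedding row
dists = torch.cdist(outputs, embeddings.weight.data)  # shape (4, 1000)
nearest_ids = dists.argmin(dim=1)                     # shape (4,), one token id per output

# Cosine similarity: argmax instead of argmin
sims = torch.nn.functional.cosine_similarity(
    outputs.unsqueeze(1),                  # (4, 1, 100)
    embeddings.weight.data.unsqueeze(0),   # (1, 1000, 100)
    dim=2,
)  # shape (4, 1000)
nearest_by_cosine = sims.argmax(dim=1)
```

`torch.cdist` avoids an explicit Python loop over the batch, which matters if you are decoding long sequences.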

Upvotes: 9
