Pingjiang Li

Reputation: 747

How to decode an embedding in PyTorch efficiently?

I am new to PyTorch and RNNs. I am learning how to use an RNN to predict numbers by following this video tutorial: https://www.youtube.com/watch?v=MKA6v99uYKY

In his code, he uses Python 3 and decodes like this:

out_unembedded = out.view(-1, hidden_size) @ embedding.weight.transpose(0,1)
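
If I understand it correctly, this scores every vocabulary word at once; the shapes below are my own annotation (it assumes the embedding dimension equals hidden_size):

# out.view(-1, hidden_size):          (steps, hidden_size)
# embedding.weight.transpose(0, 1):   (hidden_size, vocab_size)
# out_unembedded:                     (steps, vocab_size), one score per word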

I am using Python 2, so I tried:

out_unembedded = out.view(-1, hidden_size).dot(embedding.weight.transpose(0, 1))

but it does not seem right. Then I tried to decode like this:

import torch
import torch.nn as nn
from torch.autograd import Variable

word2id = {'hello': 0, 'world': 1, 'I': 2, 'am': 3, 'writing': 4, 'pytorch': 5}
embeds = nn.Embedding(6, 3)
word_embed = embeds(Variable(torch.LongTensor([word2id['am']])))

# Brute-force decode: compare the target embedding against every row of the
# embedding matrix and print the word whose row matches element for element.
id2word = {v: k for k, v in word2id.iteritems()}
for index, row in enumerate(embeds.weight.split(1)):
    if torch.min(torch.eq(row.data, word_embed.data)) == 1:  # all elements equal
        print index
        print id2word[index]

Is there a more professional way to do this? Thanks!

------------ UPDATE ------------

I found the correct way to substitute for @ in Python 2:

out_unembedded = torch.mm(embedded_output.view(-1, hidden_size), embedding.weight.transpose(0, 1))
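
To actually read the predicted words back out, I take the index of the highest score in each row (a sketch reusing the names above, plus id2word from my snippet):

_, predicted_ids = torch.max(out_unembedded, 1)  # best-scoring word id per step
for idx in predicted_ids.data.view(-1):
    print id2word[int(idx)]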

Upvotes: 1

Views: 1790

Answers (1)

Pingjiang Li

Reputation: 747

I finally figured out the problem. The two decoding methods are different.

The first one uses @ to do a matrix multiplication, which takes the dot product of the output vector with every row of the embedding matrix. Instead of searching for the exact embedding, it measures similarity (an unnormalized cosine similarity) and finds the most similar word: the value at index i is the similarity between the target vector and the word with that index. The equation is:

similarity(i) = out · W_emb[i] = Σ_j out[j] * W_emb[i][j]
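
A minimal sketch of this similarity decoding, reusing the toy embeds, word_embed, and id2word from the question (torch.mm stands in for @ on Python 2):

scores = torch.mm(word_embed.view(1, -1), embeds.weight.transpose(0, 1))
_, best_id = torch.max(scores, 1)      # index of the most similar word
print id2word[int(best_id.data[0])]    # usually 'am', but a larger-norm row can win

Because this is a plain dot product, a word whose embedding has a large norm can outscore the exact match, which is exactly why the two methods can disagree.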

The second method, which builds a hash map, finds the index by matching the exact embedding.
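
As a sketch, that hash-map idea can be made explicit (my own helper names, again reusing embeds, word_embed, and id2word from the question; it only works when the query vector is bit-for-bit identical to a stored row):

vec2id = {}
for idx, row in enumerate(embeds.weight.split(1)):
    vec2id[tuple(row.data.view(-1).tolist())] = idx

print id2word[vec2id[tuple(word_embed.data.view(-1).tolist())]]  # exact match only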

Upvotes: 1
