I am new to PyTorch and RNNs. I am learning how to use an RNN to predict numbers by following this video tutorial: https://www.youtube.com/watch?v=MKA6v99uYKY
In his code, he uses Python 3 and decodes like this:
out_unembedded = out.view(-1, hidden_size) @ embedding.weight.transpose(0,1)
I am using Python 2, so I tried:
out_unembedded = out.view(-1, hidden_size).dot(embedding.weight.transpose(0, 1))
But that does not seem right (torch.dot only accepts 1-D tensors, so it cannot replace matrix multiplication here). Then I tried to decode like this:
import torch
import torch.nn as nn
from torch.autograd import Variable

word2id = {'hello': 0, 'world': 1, 'I': 2, 'am': 3, 'writing': 4, 'pytorch': 5}
embeds = nn.Embedding(6, 3)
word_embed = embeds(Variable(torch.LongTensor([word2id['am']])))
id2word = {v: k for k, v in word2id.iteritems()}

# Scan the embedding matrix row by row; a row that matches the
# looked-up embedding exactly gives back the word id.
index = 0
for row in embeds.weight.split(1):
    if torch.min(torch.eq(row.data, word_embed.data)) == 1:
        print index
        print id2word[index]
    index += 1
Is there a more professional way to do this? Thanks!
------------ UPDATE ------------
I found the correct way to substitute for @ in Python 2:
out_unembedded = torch.mm(embedded_output.view(-1, hidden_size), embedding.weight.transpose(0, 1))
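As a quick sanity check, here is a minimal runnable sketch showing that torch.mm matches the @ operator; the shapes and names below are made up for illustration:

import torch

hidden_size, vocab_size = 3, 6
out = torch.randn(4, hidden_size)               # stand-in for flattened RNN outputs
weight = torch.randn(vocab_size, hidden_size)   # stand-in for the embedding matrix

scores = torch.mm(out, weight.transpose(0, 1))  # works in Python 2 and 3
# scores = out @ weight.transpose(0, 1)         # Python 3.5+ only (PEP 465)
print(scores.size())                            # torch.Size([4, 6]): one score per word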
I finally figured out the problem. The two decoding methods are different.
The first one uses @ to take the dot product. Instead of searching for the exact encoding, it calculates the cosine similarity via the dot product and finds the most similar word: the value after the dot product is the similarity between the target and the word at that index. The equation is:

$a \cdot b = \|a\| \|b\| \cos\theta$
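For illustration, here is a minimal sketch of this similarity-based decoding; the names out and embedding stand in for the RNN output and the embedding layer from the tutorial:

import torch
import torch.nn as nn
from torch.autograd import Variable

hidden_size, vocab_size = 3, 6
embedding = nn.Embedding(vocab_size, hidden_size)
out = Variable(torch.randn(1, hidden_size))  # stand-in for one step of RNN output

# One dot-product score per vocabulary word.
scores = torch.mm(out, embedding.weight.transpose(0, 1))

# The index with the highest score is the most similar word.
_, predicted = torch.max(scores, 1)
print(predicted)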
The second method, which builds a hash map, finds the index by matching the exact encoding.
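For reference, the exact-match scan from the question can be written a bit more compactly; this is only a sketch using torch.equal:

import torch
import torch.nn as nn
from torch.autograd import Variable

embeds = nn.Embedding(6, 3)
target = embeds(Variable(torch.LongTensor([3])))  # embedding of word id 3

for index, row in enumerate(embeds.weight.data):
    if torch.equal(row, target.data[0]):
        print(index)  # prints 3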