Reputation: 2853
Given a batch of text sequences, each sequence is converted into a tensor, with every word represented by a word embedding (a vector of 300 dimensions). I need to selectively replace the vectors of certain specific words with a new set of embeddings. Furthermore, the replacement should not happen for every occurrence of a specific word, but only for a random subset of the occurrences. Currently, I have the following code to achieve this. It traverses every word using two for loops, checks whether the word is in a specified list, splIndices, and then checks whether the word needs to be replaced, based on the True/False value at the current position of selected_.
Could this be done in a more efficient manner?
The code below may not be an MWE, but I have tried to simplify it by removing the specifics, so as to focus on the problem. Please ignore the semantics or purpose of the code, as it may not be appropriately represented in this snippet. The question is about improving performance.
splIndices = [45, 62, 2983, 456, 762]  # vocabulary indices which need to be replaced
splFreqs = 2000                        # assuming the words in splIndices occur 2000 times in total
selected_ = torch.Tensor(2000).uniform_(0, 1) > 0.2  # tensor with roughly 80% of the entries True
replIndexCtr = 0                       # counter for selected_
# Dictionary with vectors to be replaced. This is a dummy function.
# The original function depends on some property of the word.
diffVector = {45: torch.Tensor(300).uniform_(0, 1), ..., 762: torch.Tensor(300).uniform_(0, 1)}
embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
tempVals = x          # shape [32, 41] - batch of 32 sequences with 41 words each
x = embedding(x)      # shape [32, 41, 300] - the vocab indices are now replaced with embeddings
# iterate through the batch of sequences
for i, item in enumerate(x):
    # iterate through the words of each sequence
    for j, stuff in enumerate(item):
        if tempVals[i][j].item() in splIndices:
            if selected_[replIndexCtr] == True:
                x[i, j] = diffVector[tempVals[i][j].item()]
            replIndexCtr += 1
Upvotes: 2
Views: 1280
Reputation: 320
It could be vectorized in the following way:
import torch
import torch.nn as nn
import torch.nn.functional as F
batch_size, sentence_size, vocab_size, emb_size = 3, 2, 15, 1
# Use a distinct bias value as a marker of which embedder produced a vector
embedder_1 = nn.Linear(vocab_size, emb_size)
embedder_1.weight.data.fill_(0)
embedder_1.bias.data.fill_(200)
embedder_2 = nn.Linear(vocab_size, emb_size)
embedder_2.weight.data.fill_(0)
embedder_2.bias.data.fill_(404)
# Here are the indices of words which need a different embedding
replace_list = [3, 5, 7, 9]
# Make a binary mask highlighting the special words' indices
mask = torch.zeros(batch_size, sentence_size, vocab_size)
mask[..., replace_list] = 1
# Make random dataset
data_indices = torch.randint(0, vocab_size, (batch_size, sentence_size))
data_onehot = F.one_hot(data_indices, vocab_size)
# Check whether the one-hot of a word collides with the replace mask
replace_mask = mask.long() * data_onehot
replace_mask = torch.sum(replace_mask, dim=-1).byte()  # byte() is critical here: a byte mask indexes by mask, not by integer position
data_emb = torch.empty(batch_size, sentence_size, emb_size)
# Fill default embeddings
data_emb[1-replace_mask] = embedder_1(data_onehot[1-replace_mask].float())
if torch.max(replace_mask) != 0: # If not all zeros
# Fill special embeddings
data_emb[replace_mask] = embedder_2(data_onehot[replace_mask].float())
print(data_indices)
print(replace_mask)
print(data_emb.squeeze(-1).int())
Here is an example of a possible output:
# Word indices
tensor([[ 6,  9],
        [ 5, 10],
        [ 4, 11]])
# Embedding replacement mask
tensor([[0, 1],
        [1, 0],
        [0, 0]], dtype=torch.uint8)
# Resulting replacement
tensor([[200, 404],
        [404, 200],
        [200, 200]], dtype=torch.int32)
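For completeness, the same masking idea can also be applied directly to the output of an nn.Embedding layer, closer to the setup in the question. This is only a rough sketch (vocab_size, the [32, 41] batch shape and the splIndices values follow the question; the embedding layer, the replacement vectors diffVectors and the 20% random selection are dummy stand-ins), not a drop-in replacement for the original code:
import torch
import torch.nn as nn

vocab_size, emb_size = 3000, 300
splIndices = torch.tensor([45, 62, 456, 762, 2983])
# One replacement vector per special vocab index (dummy values here)
diffVectors = torch.rand(len(splIndices), emb_size)

embedding = nn.Embedding(vocab_size, emb_size)
x_idx = torch.randint(0, vocab_size, (32, 41))   # batch of vocab indices
x_emb = embedding(x_idx)                         # [32, 41, 300]

# Boolean mask of positions holding any of the special indices: [32, 41]
spl_mask = (x_idx.unsqueeze(-1) == splIndices).any(dim=-1)
# Keep roughly 80% of those positions, mirroring selected_ from the question
spl_mask &= torch.rand(x_idx.shape) > 0.2

# For each selected position, find which special index it holds and
# overwrite its embedding with the corresponding replacement vector
row = (x_idx[spl_mask].unsqueeze(-1) == splIndices).float().argmax(dim=-1)
x_emb[spl_mask] = diffVectors[row]
The double for loop disappears entirely; the only per-element work left is the broadcasted comparison against splIndices.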
Upvotes: 1