Amrith Krishna

Reputation: 2853

Efficient way of selectively replacing vectors from a tensor in pytorch

Given a batch of text sequences, the batch is converted into a tensor in which every word is represented by a word embedding (a 300-dimensional vector). I need to selectively replace the vectors of certain specific words with a new set of embeddings. Moreover, the replacement should happen not for every occurrence of a specific word, but only for randomly chosen occurrences. Currently, I achieve this with the code below. It traverses every word using two nested for loops, checks whether the word is in a specified list, splIndices, and then decides whether the vector should be replaced based on the corresponding True/False value in selected_.

But could this be done in a more efficient manner?

The code below may not be an MWE, but I have simplified it by removing the specifics so as to focus on the problem. Please ignore the semantics or purpose of the code, as it may not be accurately represented in this snippet; the question is about improving performance.


splIndices = [45, 62, 2983, 456, 762]  # vocabulary indices whose vectors need to be replaced
splFreqs = 2000  # assuming the words in splIndices occur 2000 times in total
selected_ = torch.Tensor(2000).uniform_(0, 1) > 0.8  # tensor with roughly 20% of the entries True
replIndexCtr = 0  # counter for selected_

# Dictionary with the replacement vectors. This is a dummy;
# the original function depends on some property of the word
diffVector = {45: torch.Tensor(300).uniform_(0, 1), ..., 762: torch.Tensor(300).uniform_(0, 1)}

embedding = nn.Embedding.from_pretrained(embedding_matrix, freeze=False)
tempVals = x  # shape [32, 41] - batch of 32 sequences with 41 words each
x = embedding(x)  # shape [32, 41, 300] - vocabulary indices replaced by their embeddings

# iterate over the sequences in the batch
for i, item in enumerate(x):
    # iterate over the words in a sequence
    for j, stuff in enumerate(item):
        if tempVals[i][j].item() in splIndices:
            if selected_[replIndexCtr]:
                x[i, j] = diffVector[tempVals[i][j].item()]
            replIndexCtr += 1


Upvotes: 2

Views: 1280

Answers (1)

mcstarioni

Reputation: 320

It could be vectorized in the following way:

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, sentence_size, vocab_size, emb_size = 3, 2, 15, 1

# Give each embedder a distinctive constant bias so its output marks which one was used
embedder_1 = nn.Linear(vocab_size, emb_size)
embedder_1.weight.data.fill_(0)
embedder_1.bias.data.fill_(200)

embedder_2 = nn.Linear(vocab_size, emb_size)
embedder_2.weight.data.fill_(0)
embedder_2.bias.data.fill_(404)

# Here are the indices of words which need a different embedding
replace_list = [3, 5, 7, 9] 

# Make a binary mask highlighting the special words' indices
mask = torch.zeros(batch_size, sentence_size, vocab_size)
mask[..., replace_list] = 1

# Make random dataset
data_indices = torch.randint(0, vocab_size, (batch_size, sentence_size))
data_onehot = F.one_hot(data_indices, vocab_size)

# Check whether the one-hot of a word collides with the replace mask
replace_mask = mask.long() * data_onehot
replace_mask = torch.sum(replace_mask, dim=-1).byte()  # byte() is critical here: a uint8 tensor acts as a boolean mask when indexing

data_emb = torch.empty(batch_size, sentence_size, emb_size)

# Fill default embeddings
data_emb[1-replace_mask] = embedder_1(data_onehot[1-replace_mask].float())
if torch.max(replace_mask) != 0: # If not all zeros
    # Fill special embeddings
    data_emb[replace_mask] = embedder_2(data_onehot[replace_mask].float())

print(data_indices)
print(replace_mask)
print(data_emb.squeeze(-1).int())

Here is an example of a possible output:

# Word indices
tensor([[ 6,  9],
        [ 5, 10],
        [ 4, 11]])
# Embedding replacement mask
tensor([[0, 1],
        [1, 0],
        [0, 0]], dtype=torch.uint8)
# Resulting replacement
tensor([[200, 404],
        [404, 200],
        [200, 200]], dtype=torch.int32)
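
For the setup in the question (an nn.Embedding lookup with special vocabulary indices and a random 20% selection), a minimal sketch along the same lines, without the one-hot encoding, might look as follows. It assumes torch.isin (available in PyTorch 1.10+) and a hypothetical repl_matrix holding one replacement vector per vocabulary index; neither is part of the original code.

import torch
import torch.nn as nn

vocab_size, emb_size = 5000, 300
splIndices = torch.tensor([45, 62, 2983, 456, 762])

embedding = nn.Embedding(vocab_size, emb_size)

# Hypothetical replacement vectors, one row per vocabulary index
# (rows for non-special indices are simply never used)
repl_matrix = torch.rand(vocab_size, emb_size)

x = torch.randint(0, vocab_size, (32, 41))   # batch of token indices, shape [32, 41]
emb = embedding(x)                           # shape [32, 41, 300]

special = torch.isin(x, splIndices)          # True at positions holding a special word
selected = torch.rand(x.shape) < 0.2         # select ~20% of positions at random
replace = special & selected                 # final boolean replacement mask, shape [32, 41]

# Replace the embeddings at the selected positions in one vectorized step,
# mirroring the in-place assignment from the question's loop
emb[replace] = repl_matrix[x[replace]]

This avoids both the one-hot tensors and the Python loops, so the cost of the replacement scales with the number of selected positions rather than with the full batch.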

Upvotes: 1
