patti_jane

Reputation: 3841

BERT embeddings in batches

I am following this post to extract embeddings for sentences, and for a single sentence the steps are described as follows:

    text = "After stealing money from the bank vault, the bank robber was seen " \
           "fishing on the Mississippi river bank."
    
    # Add the special tokens. 
    marked_text = "[CLS] " + text + " [SEP]"
    
    # Split the sentence into tokens. 
    tokenized_text = tokenizer.tokenize(marked_text)
    
    # Mark each of the 22 tokens as belonging to sentence "1". 
    segments_ids = [1] * len(tokenized_text)
    
    # Convert inputs to PyTorch tensors 
    tokens_tensor = torch.tensor([indexed_tokens]) 
    segments_tensors = torch.tensor([segments_ids])
    
    # Load pre-trained model (weights) 
    model = BertModel.from_pretrained('bert-base-uncased',
                                 output_hidden_states = True,
                                      )
    
    # Put the model in "evaluation" mode, meaning feed-forward operation. 
    model.eval()
    
    with torch.no_grad():
        outputs = model(tokens_tensor, segments_tensors)
        hidden_states = outputs[2]

And I want to do this for a batch of sequences. Here is my example code:

    seql = ['this is an example', 'today was sunny and', 'today was']
    encoded = [tokenizer.encode(seq, max_length=5, pad_to_max_length=True) for seq in seql]
    
    encoded
    [[2, 2511, 1840, 3251, 3],
     [2, 1663, 2541, 1957, 3],
     [2, 1663, 2541, 3, 0]]

But since I'm working with batches, the sequences need to have the same length, so I introduce a padding token (third sentence), which confuses me on several points.
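For reference, this is roughly how I imagine feeding the padded batch to the model (an assumption on my part; I'm guessing I also need an attention mask built from tokenizer.pad_token_id to mark the padding positions, which is part of what I'm unsure about):

    import torch
    
    # Build a batch tensor from the padded ids and mask out the padding positions.
    tokens_tensor = torch.tensor(encoded)                              # shape: (3, 5)
    attention_mask = (tokens_tensor != tokenizer.pad_token_id).long()  # 1 = real token, 0 = padding
    
    model.eval()
    with torch.no_grad():
        outputs = model(tokens_tensor, attention_mask=attention_mask)
        hidden_states = outputs[2]   # one (batch, seq_len, hidden) tensor per layer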

Upvotes: 1

Views: 4317

Answers (1)

laifi

Reputation: 96

You can do all the work you need (padding, truncation) with a single function: encode_plus. Check its parameters in the docs.
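For example, a minimal sketch (assuming a recent transformers version, where padding='max_length' and truncation=True replace the older pad_to_max_length flag):

    encoded = tokenizer.encode_plus(
        'this is an example',
        max_length=5,
        padding='max_length',   # pad shorter sequences up to max_length
        truncation=True,        # truncate longer sequences down to max_length
        return_tensors='pt',    # return PyTorch tensors
    )
    # encoded is a dict-like object with 'input_ids', 'token_type_ids' and 'attention_mask'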

You can do the same for a list of sequences with batch_encode_plus (see the docs).
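And a sketch for the list of sequences from the question; the returned attention_mask tells the model which positions are padding (this assumes the model was loaded with output_hidden_states=True, as in the question):

    import torch
    
    batch = tokenizer.batch_encode_plus(
        ['this is an example', 'today was sunny and', 'today was'],
        max_length=5,
        padding='max_length',
        truncation=True,
        return_tensors='pt',
    )
    
    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=batch['input_ids'],
                        attention_mask=batch['attention_mask'])
        # hidden_states is a tuple with one (batch, seq_len, hidden) tensor per layer
        hidden_states = outputs[2]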

Upvotes: 1
