Reputation: 2079
I'm trying to implement a beam search decoding strategy in a text generation model. This is the function that I am using to decode the output probabilities.
import torch

def beam_search_decoder(data, k):
    sequences = [[list(), 0.0]]
    # walk over each step in the sequence
    for row in data:
        all_candidates = list()
        # expand each current candidate with every possible next token
        for i in range(len(sequences)):
            seq, score = sequences[i]
            for j in range(len(row)):
                candidate = [seq + [j], score - torch.log(row[j])]
                all_candidates.append(candidate)
        # sort candidates by score (lower accumulated negative log-probability is better)
        ordered = sorted(all_candidates, key=lambda tup: tup[1])
        sequences = ordered[:k]
    return sequences
Now you can see this function is written with batch_size 1 in mind. Adding another loop over the batch dimension would make the algorithm O(n^4), and it is already slow as it is. Is there any way to improve the speed of this function? My model output is usually of size (32, 150, 9907), which follows the format (batch_size, max_len, vocab_size).
Upvotes: 6
Views: 9808
Reputation: 639
Based on the version proposed by 防暴队大盾, I decided to implement a version of the beam-search algorithm that does not overlook sequences that share initial tokens. This is done by retrieving the correct indices from the indices of the flattened array:
import torch

def beam_search(prediction, k=10):
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for n1 in range(1, seq_length):
        # expand each of the k beams with all vocab_size tokens for step n1
        log_prob_temp = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index_temp = log_prob_temp.view(batch_size, -1).topk(k, sorted=True)
        idx_begin = index_temp // vocab_size  # retrieve index of start sequence
        idx_concat = index_temp % vocab_size  # retrieve index of new token
        new_indices = torch.zeros((batch_size, k, n1 + 1), dtype=torch.int64)
        for n2 in range(batch_size):
            new_indices[n2, :, :-1] = indices[n2][idx_begin[n2]]
            new_indices[n2, :, -1] = idx_concat[n2]
        indices = new_indices
    return indices, log_prob
This version assumes that prediction corresponds to cross-entropy (log-space) scores rather than probabilities, so there is no need to take the log here.
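For example (this is my assumption about the typical use, not something stated in the answer), if your model outputs raw logits, a log-softmax produces log-space scores that can be passed straight to beam_search:

import torch
import torch.nn.functional as F

logits = torch.randn(32, 150, 9907)           # stand-in for raw model output (batch, max_len, vocab)
log_scores = F.log_softmax(logits, dim=-1)    # already in log space, so no further log inside beam_search
indices, log_prob = beam_search(log_scores, k=5)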
If someone knows how to avoid the inner-most loop with some fancy indexing, one can probably make this even faster.
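One possible way, sketched here rather than taken from the answer above: torch.gather can reorder the beam prefixes for the whole batch in a single call, which removes the inner loop over batch_size. The name beam_search_vectorized is made up for this sketch; everything except the gather/cat lines mirrors the code above.

import torch

def beam_search_vectorized(prediction, k=10):
    # Same search as above, with the per-batch loop replaced by torch.gather.
    batch_size, seq_length, vocab_size = prediction.shape
    log_prob, indices = prediction[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)                               # (batch_size, k, 1)
    for n1 in range(1, seq_length):
        log_prob_temp = log_prob.unsqueeze(-1) + prediction[:, n1, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index_temp = log_prob_temp.view(batch_size, -1).topk(k, sorted=True)
        idx_begin = index_temp // vocab_size                      # which beam each candidate extends
        idx_concat = index_temp % vocab_size                      # new token of each candidate
        # gathered[b, i, :] == indices[b, idx_begin[b, i], :], for all batches at once
        gathered = indices.gather(1, idx_begin.unsqueeze(-1).expand(-1, -1, indices.size(-1)))
        indices = torch.cat([gathered, idx_concat.unsqueeze(-1)], dim=-1)
    return indices, log_prob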
Upvotes: 1
Reputation: 1
You can use this library:
https://pypi.org/project/pytorch-beam-search/
It implements Beam Search, Greedy Search and sampling for PyTorch sequence models.
The following snippet implements a Transformer seq2seq model and uses it to generate predictions.
#pip install pytorch-beam-search
from pytorch_beam_search import seq2seq
# Create vocabularies
# Tokenize the way you need
source = [list("abcdefghijkl"), list("mnopqrstwxyz")]
target = [list("ABCDEFGHIJKL"), list("MNOPQRSTWXYZ")]
# An Index object represents a mapping from the vocabulary
# to integers (indices) to feed into the models
source_index = seq2seq.Index(source)
target_index = seq2seq.Index(target)
# Create tensors
X = source_index.text2tensor(source)
Y = target_index.text2tensor(target)
# X.shape == (n_source_examples, len_source_examples) == (2, 11)
# Y.shape == (n_target_examples, len_target_examples) == (2, 12)
# Create and train the model
model = seq2seq.Transformer(source_index, target_index) # just a PyTorch model
model.fit(X, Y, epochs = 100) # basic method included
# Generate new predictions
new_source = [list("new first in"), list("new second in")]
new_target = [list("new first out"), list("new second out")]
X_new = source_index.text2tensor(new_source)
Y_new = target_index.text2tensor(new_target)
loss, error_rate = model.evaluate(X_new, Y_new) # basic method included
predictions, log_probabilities = seq2seq.beam_search(model, X_new)
output = [target_index.tensor2text(p) for p in predictions]
output
Upvotes: -2
Reputation: 91
Below is my implementation, which may be a little faster than the for-loop implementation.
import torch
def beam_search_decoder(post, k):
    """Beam Search Decoder

    Parameters:
        post(Tensor) – the posterior of network.
        k(int) – beam size of decoder.

    Outputs:
        indices(Tensor) – a beam of index sequence.
        log_prob(Tensor) – a beam of log likelihood of sequence.

    Shape:
        post: (batch_size, seq_length, vocab_size).
        indices: (batch_size, beam_size, seq_length).
        log_prob: (batch_size, beam_size).

    Examples:
        >>> post = torch.softmax(torch.randn([32, 20, 1000]), -1)
        >>> indices, log_prob = beam_search_decoder(post, 3)
    """
    batch_size, seq_length, _ = post.shape
    log_post = post.log()
    log_prob, indices = log_post[:, 0, :].topk(k, sorted=True)
    indices = indices.unsqueeze(-1)
    for i in range(1, seq_length):
        log_prob = log_prob.unsqueeze(-1) + log_post[:, i, :].unsqueeze(1).repeat(1, k, 1)
        log_prob, index = log_prob.view(batch_size, -1).topk(k, sorted=True)
        indices = torch.cat([indices, index.unsqueeze(-1)], dim=-1)
    return indices, log_prob
Upvotes: 9