Exploring Beam Search Decoding for Text Generation in Python
Chapter 1: Introduction to Text Generation Techniques
In our previous discussion, we explored a prevalent approach to text generation known as Greedy Search decoding. This method focuses on selecting the word with the highest probability at every timestep.
However, Beam Search decoding takes a different approach. Instead of keeping only the single most probable token at each timestep, it tracks the top-k most probable candidate sequences (beams), extends each of them, and scores every candidate by the sum of its token log probabilities. The sequence with the highest overall log probability is selected.
To illustrate this, consider the probabilities of two continuations of the phrase "Pancakes looks" at timestep 1. For instance:
- "Pancakes looks so" = log(0.2) + log(0.7) = -1.9
- "Pancakes looks fluffy" = log(0.2) + log(0.3) = -2.8
Now, let’s implement functions to compute the log probability for an entire sentence.
import torch
import torch.nn.functional as F

def log_probability_single(logits, labels):
    # Log-softmax over the vocabulary, then pick out the log probability of each actual token
    logp = F.log_softmax(logits, dim=-1)
    logp_label = torch.gather(logp, 2, labels.unsqueeze(2)).squeeze(-1)
    return logp_label
def sentence_logprob(model, labels, input_len=0):
    with torch.no_grad():
        result = model(labels)
        # Shift logits and labels so each position scores the next token
        log_probability = log_probability_single(result.logits[:, :-1, :],
                                                 labels[:, 1:])
        # Sum the log probabilities of the generated part only
        sentence_log_prob = torch.sum(log_probability[:, input_len:])
    return sentence_log_prob.cpu().numpy()
Next, we can apply this to the output generated by Greedy Search and compute the log probability of the produced sequence. For demonstration, let's use a brief blurb describing Haruki Murakami's novel 1Q84.
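The snippets below assume that model, tokenizer, and device were already defined in the previous post. As a minimal sketch, assuming a GPT-2 checkpoint from Hugging Face transformers (the model name is just one possible choice), that setup might look like this:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumed setup: any causal language model works; "gpt2" is only an example
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)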
input_sentence = "A love story, a mystery, a fantasy, a novel of self-discovery, a dystopia to rival George Orwell's — 1Q84 is Haruki Murakami's most ambitious undertaking yet: an instant best seller in his native Japan, and a tremendous feat of imagination from one of our most revered contemporary writers."
max_sequence = 100
input_ids = tokenizer(input_sentence, return_tensors='pt')['input_ids'].to(device)
output = model.generate(input_ids, max_length=max_sequence, do_sample=False)
greedy_search_output = sentence_logprob(model, output, input_len=len(input_ids[0]))
print(tokenizer.decode(output[0]))
print(f"nlog_prob: {greedy_search_output:.2f}")
From our results, the generated sequence has a log probability of -52.31.
Now, let’s apply Beam Search decoding to compare the log probability scores of sequences generated using this method. We can set the num_beams parameter of model.generate() to enable it, where a greater number of beams typically yields better sequences at the cost of more computation. Additionally, we apply the n-gram penalty parameter no_repeat_ngram_size to mitigate repetitive sequences in the output.
beam_search_output = model.generate(input_ids,
                                    max_length=max_sequence,
                                    num_beams=5,
                                    do_sample=False,
                                    no_repeat_ngram_size=2)
beam_search_log_prob = sentence_logprob(model, beam_search_output, input_len=len(input_ids[0]))
print(tokenizer.decode(beam_search_output[0]))
print(f"nlog_prob: {beam_search_log_prob:.2f}")
The results from Beam Search decoding show a significant improvement in coherence and structure.
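Putting the two scores side by side makes the comparison concrete (the variable names are those defined in the snippets above):

print(f"greedy log_prob: {greedy_search_output:.2f}")
print(f"beam   log_prob: {beam_search_log_prob:.2f}")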
To further enrich your understanding, check out the following videos:
- How-to Decode Outputs From NLP Models (Python) - YouTube: A practical guide on interpreting model outputs in NLP.
- UMass CS685 S23 (Advanced NLP) #10: Decoding from language models - YouTube: An advanced look at decoding techniques used in NLP.
In conclusion, this post has provided insights into Beam Search decoding, showcasing its effectiveness in generating coherent text based on specific inputs. If you're eager to delve deeper, consider exploring more resources on this topic and supporting writers on platforms like Medium.