I'm using the Hugging Face transformers library to generate text with the pre-trained distilgpt2 model. In particular, I'm calling the beam_search method directly, because I want to pass a LogitsProcessorList (which you can't use with the generate function).
The relevant portion of my code looks like this:
```python
from transformers import BeamSearchScorer

beam_scorer = BeamSearchScorer(
    batch_size=btchsze,
    max_length=15,  # not sure why lengths under 20 fail
    num_beams=num_seq,
    device=model.device,
)
# one copy of the prompt per beam and per batch element
j = input_ids.tile((num_seq * btchsze, 1))
next_output = model.beam_search(
    j,
    beam_scorer,
    eos_token_id=tokenizer.encode('.')[0],
    logits_processor=logits_processor,
)
```
However, the beam_search function throws this error when I try to generate using a max_length of less than 20:
```
~/anaconda3/envs/techtweets37/lib/python3.7/site-packages/transformers-4.4.2-py3.8.egg/transformers/generation_beam_search.py in finalize(self, input_ids, final_beam_scores, final_beam_tokens, final_beam_indices, pad_token_id, eos_token_id)
    326         # fill with hypotheses and eos_token_id if the latter fits in
    327         for i, hypo in enumerate(best):
--> 328             decoded[i, : sent_lengths[i]] = hypo
    329             if sent_lengths[i] < self.max_length:
    330                 decoded[i, sent_lengths[i]] = eos_token_id

RuntimeError: The expanded size of the tensor (15) must match the existing size (20) at non-singleton dimension 0. Target sizes: [15]. Tensor sizes: [20]
```
I can't figure out where the 20 is coming from: it stays the same whether the input is longer or shorter, and whatever batch size or number of beams I use. Nothing I've defined has length 20, and I can't find a default that does. The max length of the sequence does affect the results of the beam search, so I'd like to understand this and be able to set a shorter max length.
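This is a known issue in the Hugging Face transformers library:
https://github.com/huggingface/transformers/issues/11040
Basically, the beam search isn't using the max_length passed to the BeamSearchScorer, but the max_length of the model (model.config.max_length, which defaults to 20; that is where the 20 comes from).
For now, the fix is to set model.config.max_length to the desired max length (e.g. model.config.max_length = 15) before calling beam_search.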
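To make the mismatch concrete, here is a simplified, hypothetical model of the two code paths involved. It does not call transformers at all; the functions beam_search_length and finalize are illustrative stand-ins, and the assumption (based on the traceback above) is that the search loop runs to its own max_length or falls back to the config value of 20, while the scorer caps its output rows at the max_length it was constructed with. It also assumes no hypothesis finishes early on the EOS token.

```python
# Simplified, hypothetical model of the length mismatch; NOT the transformers code.

CONFIG_DEFAULT_MAX_LENGTH = 20  # default max_length of a model config


def beam_search_length(max_length=None, config_max_length=CONFIG_DEFAULT_MAX_LENGTH):
    """Length the search loop runs to: its own argument, else the config value."""
    return max_length if max_length is not None else config_max_length


def finalize(hypo_len, scorer_max_length):
    """Mimic the scorer's output buffer, whose rows are capped at scorer_max_length."""
    sent_max_len = min(hypo_len + 1, scorer_max_length)
    if hypo_len > sent_max_len:
        # the hypothesis no longer fits in the preallocated row
        raise RuntimeError(
            f"The expanded size of the tensor ({sent_max_len}) must match "
            f"the existing size ({hypo_len})"
        )
    return sent_max_len


# Scorer built with max_length=15, search left on the config default: fails.
try:
    finalize(beam_search_length(), scorer_max_length=15)
except RuntimeError as err:
    print("error:", err)

# Workaround: make the config agree with the scorer (like model.config.max_length = 15).
print("buffer length:", finalize(beam_search_length(config_max_length=15), scorer_max_length=15))
```

The first call reproduces the 15 vs. 20 sizes from the question; once both lengths agree, the hypotheses fit the buffer.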