Ahmad

How to get the top N generated texts from a T5 transformer?

I trained a T5 model using the simpletransformers library.

Here is the code I use to get the predictions:

pred_values = model.predict(input_values)

However, it only returns the top (greedy) prediction. How can I get the top 10 results?

Answer

Ahmad

The parameter you need is num_return_sequences, which sets how many sequences are returned for each input. If you want those sequences to come from beam search, also set num_beams; it must be at least as large as num_return_sequences.

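from simpletransformers.t5 import T5Args

model_args = T5Args()
model_args.num_beams = 5  # beam width; must be >= num_return_sequences
model_args.num_return_sequences = 2  # sequences returned per input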
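To see why num_beams has to be at least num_return_sequences, here is a minimal pure-Python sketch of beam search over a hypothetical toy next-token model. The names toy_model, beam_search and the "</s>" end token are illustrative assumptions, not simpletransformers or Hugging Face APIs, and the sketch omits refinements such as length penalties. At every step it keeps only the num_beams highest-scoring partial sequences, then returns the best num_return_sequences of them, so asking for more sequences than beams cannot work.

```python
import math
from typing import Callable, Dict, List, Tuple

EOS = "</s>"  # hypothetical end-of-sequence token for this toy example


def toy_model(prefix: Tuple[str, ...]) -> Dict[str, float]:
    """Hypothetical stand-in for T5's decoder: next-token probabilities given a prefix."""
    if len(prefix) == 0:
        return {"a": 0.7, "b": 0.3}
    if len(prefix) == 1:
        return {"x": 0.5, "y": 0.3, EOS: 0.2}
    return {EOS: 1.0}


def beam_search(
    next_token_probs: Callable[[Tuple[str, ...]], Dict[str, float]],
    num_beams: int,
    num_return_sequences: int,
    max_length: int,
) -> List[Tuple[Tuple[str, ...], float]]:
    """Return the best num_return_sequences (tokens, log-prob) pairs kept by beam search."""
    if num_return_sequences > num_beams:
        # Only num_beams hypotheses survive each step, so more cannot be returned.
        raise ValueError("num_return_sequences must be <= num_beams")
    beams: List[Tuple[Tuple[str, ...], float]] = [((), 0.0)]
    for _ in range(max_length):
        candidates = []
        for tokens, score in beams:
            if tokens and tokens[-1] == EOS:
                candidates.append((tokens, score))  # finished: carry over unchanged
                continue
            for tok, p in next_token_probs(tokens).items():
                candidates.append((tokens + (tok,), score + math.log(p)))
        # Keep only the num_beams highest-scoring hypotheses.
        beams = sorted(candidates, key=lambda c: c[1], reverse=True)[:num_beams]
        if all(tokens and tokens[-1] == EOS for tokens, _ in beams):
            break
    return beams[:num_return_sequences]


for tokens, score in beam_search(toy_model, num_beams=3, num_return_sequences=2, max_length=5):
    print(" ".join(tokens), round(math.exp(score), 2))
```

Hugging Face's generate() also rejects num_return_sequences larger than num_beams when beam search is used. For the question's "top 10", that means setting num_return_sequences = 10 together with num_beams of at least 10; predict() should then return several candidates per input instead of one, so check the shape of the returned predictions in your version of simpletransformers.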

Alternatively, you can sample with top_k or top_p and return several of the samples; in that case you must set do_sample to True. Because these are random draws, they are not ranked by score the way beam-search outputs are. For more information about the parameters, see [1]; [2] gives a detailed explanation of these decoding strategies.

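from simpletransformers.t5 import T5Args

model_args = T5Args()
model_args.do_sample = True  # required for top_k / top_p sampling
model_args.top_p = 0.9  # nucleus (top-p) sampling threshold
model_args.num_return_sequences = 2  # independent samples returned per input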
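The sampling route can be sketched the same way. The hypothetical sample_sequences helper below (like nucleus, toy_model and the "</s>" token, an illustrative name rather than a library API) applies top-p filtering at each step: it keeps the smallest set of most likely tokens whose cumulative probability reaches top_p, renormalises, and draws one token. It repeats this num_return_sequences times, producing independent samples. It is a sketch of the technique, not the library's implementation.

```python
import random
from typing import Dict, List, Tuple

EOS = "</s>"  # hypothetical end-of-sequence token for this toy example


def toy_model(prefix: Tuple[str, ...]) -> Dict[str, float]:
    """Hypothetical stand-in for T5's decoder: next-token probabilities given a prefix."""
    if len(prefix) == 0:
        return {"a": 0.7, "b": 0.3}
    if len(prefix) == 1:
        return {"x": 0.5, "y": 0.3, EOS: 0.2}
    return {EOS: 1.0}


def nucleus(probs: Dict[str, float], top_p: float) -> Dict[str, float]:
    """Smallest set of most likely tokens whose total probability reaches top_p, renormalised."""
    kept: Dict[str, float] = {}
    total = 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept[tok] = p
        total += p
        if total >= top_p:
            break
    return {tok: p / total for tok, p in kept.items()}


def sample_sequences(next_token_probs, top_p: float, num_return_sequences: int,
                     max_length: int, seed: int = 0) -> List[Tuple[str, ...]]:
    """Draw num_return_sequences independent sequences with top-p sampling."""
    rng = random.Random(seed)
    outputs = []
    for _ in range(num_return_sequences):
        tokens: Tuple[str, ...] = ()
        while len(tokens) < max_length and (not tokens or tokens[-1] != EOS):
            dist = nucleus(next_token_probs(tokens), top_p)
            tok = rng.choices(list(dist), weights=list(dist.values()), k=1)[0]
            tokens += (tok,)
        outputs.append(tokens)
    return outputs


for tokens in sample_sequences(toy_model, top_p=0.75, num_return_sequences=5, max_length=5):
    print(" ".join(tokens))
```

With top_p=0.75, the end token after the first word (probability 0.2) falls outside the nucleus, so every sample has the form "a|b x|y </s>". Different seeds give different outputs, and the same sequence can appear more than once, which is why sampled results are not a ranked top N.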

[1] https://simpletransformers.ai/docs/t5-model/

[2] https://huggingface.co/blog/how-to-generate
