I trained a T5 transformer using the simpletransformers library.
Here is the code I use to get the predictions:
pred_values = model.predict(input_values)
However, it only returns the top (greedy) prediction. How can I get the top 10 results?
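The parameter you need is num_return_sequences, which sets how many generated sequences are returned for each input. If you want to use beam search, you should also set num_beams. num_return_sequences cannot be larger than num_beams, so to get the top 10 results you need num_beams of at least 10 and num_return_sequences = 10.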
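from simpletransformers.t5 import T5Args

model_args = T5Args()
model_args.num_beams = 5  # number of hypotheses kept at each decoding step
model_args.num_return_sequences = 2  # must be <= num_beams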
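To see what num_beams and num_return_sequences control, here is a minimal, library-free sketch of beam search in plain Python. It is not the simpletransformers or Hugging Face implementation: it has no length penalty, and toy_model is a hypothetical stand-in for T5 that returns fixed next-token probabilities. The sketch rejects num_return_sequences > num_beams, mirroring the check in Hugging Face generate().

```python
import math
from typing import Callable, Dict, List, Tuple

# Hypothetical stand-in for a seq2seq model: given the tokens generated so
# far, return a probability for each possible next token.
NextTokenFn = Callable[[Tuple[str, ...]], Dict[str, float]]


def beam_search(model: NextTokenFn, num_beams: int, num_return_sequences: int,
                max_length: int, eos: str = "</s>") -> List[Tuple[Tuple[str, ...], float]]:
    """Simplified beam search (no length penalty) returning the
    num_return_sequences best (tokens, log_prob) pairs."""
    if num_return_sequences > num_beams:
        # Hugging Face generate() enforces the same constraint.
        raise ValueError("num_return_sequences must be <= num_beams")
    beams = [((), 0.0)]  # (tokens so far, cumulative log-probability)
    finished = []
    for _ in range(max_length):
        candidates = [
            (tokens + (tok,), score + math.log(p))
            for tokens, score in beams
            for tok, p in model(tokens).items()
        ]
        # Keep only the num_beams highest-scoring extensions.
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:num_beams]:
            # Hypotheses that emit EOS stop growing.
            (finished if seq[-1] == eos else beams).append((seq, score))
        if not beams:
            break
    finished.extend(beams)  # hypotheses that reached max_length without EOS
    finished.sort(key=lambda c: c[1], reverse=True)
    return finished[:num_return_sequences]


def toy_model(prefix: Tuple[str, ...]) -> Dict[str, float]:
    # Hypothetical fixed distributions, just to make the example runnable.
    if not prefix:
        return {"a": 0.5, "b": 0.3, "c": 0.2}
    return {"x": 0.6, "</s>": 0.4}


for seq, score in beam_search(toy_model, num_beams=2, num_return_sequences=2, max_length=2):
    print(seq, round(math.exp(score), 3))
# ('a', 'x') 0.3
# ('a', '</s>') 0.2
```

With num_beams=2 the hypothesis starting with "b" is pruned at the second step, so both returned sequences start with "a". For the "top 10" case in the question you would use num_beams >= 10 and num_return_sequences=10.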
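Alternatively, you can use sampling with top_k (sample only from the k most likely next tokens) or top_p (sample only from the smallest set of tokens whose cumulative probability reaches p). In both cases you must set do_sample to True. Note that sampled sequences are random draws rather than a ranked top-n list, so duplicates are possible. For more information about these parameters, see [1] and [2]; [2] gives a detailed explanation of the decoding strategies.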
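from simpletransformers.t5 import T5Args

model_args = T5Args()
model_args.do_sample = True  # sample instead of greedy/beam decoding
model_args.top_p = 0.9  # nucleus sampling threshold
model_args.num_return_sequences = 2  # independent samples per input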
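The difference between top_k and top_p can be sketched for a single decoding step in plain Python. This is a simplified illustration, not the Hugging Face code: the real implementation filters logits at every generation step, while this sketch filters one hypothetical next-token distribution (next_token) and then draws independent samples, which is roughly what each of the num_return_sequences outputs does.

```python
import random
from typing import Dict, List


def top_k_filter(probs: Dict[str, float], k: int) -> Dict[str, float]:
    """Keep only the k most likely tokens and renormalise."""
    kept = sorted(probs.items(), key=lambda kv: kv[1], reverse=True)[:k]
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


def top_p_filter(probs: Dict[str, float], top_p: float) -> Dict[str, float]:
    """Nucleus filtering: keep the smallest set of most likely tokens whose
    cumulative probability reaches top_p, then renormalise."""
    kept, cumulative = [], 0.0
    for tok, p in sorted(probs.items(), key=lambda kv: kv[1], reverse=True):
        kept.append((tok, p))
        cumulative += p
        if cumulative >= top_p:
            break
    total = sum(p for _, p in kept)
    return {tok: p / total for tok, p in kept}


def draw(probs: Dict[str, float], n: int, rng: random.Random) -> List[str]:
    """Draw n independent samples; duplicates are possible and the
    results are not ranked."""
    tokens = list(probs)
    return rng.choices(tokens, weights=[probs[t] for t in tokens], k=n)


# Hypothetical next-token distribution, just to make the example runnable.
next_token = {"the": 0.5, "a": 0.25, "an": 0.15, "this": 0.1}
print(top_k_filter(next_token, k=2))         # keeps "the" and "a"
print(top_p_filter(next_token, top_p=0.8))   # keeps "the", "a" and "an"
print(draw(top_p_filter(next_token, top_p=0.8), n=2, rng=random.Random(0)))
```

top_k always keeps a fixed number of candidates, while top_p adapts the candidate set to how peaked the distribution is; in both cases the tokens outside the kept set can never be sampled.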
[1] https://simpletransformers.ai/docs/t5-model/
[2] https://huggingface.co/blog/how-to-generate