Reputation: 61
I am using the T5 model found on Hugging Face for text summarization. How can I output the logits of the T5 model directly given a text input for generation purposes (not training)?
I want to generate the output token by token so that I can calculate the entropy of each output token. The .generate() method does not seem to work for this.
I effectively want to create my own generate function but I need to obtain the logits of the model to be able to do this.
Upvotes: 2
Views: 4628
Reputation: 989
You can use the model's forward function to get the logits and then apply argmax to them, like this:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch.nn.functional as F
tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
inputs = tokenizer("test here", padding="longest",
                   max_length=128,
                   truncation=True,
                   return_tensors="pt")
# T5 is an encoder-decoder model, so forward() also needs decoder input ids;
# start decoding from the decoder start token (the pad token for T5).
decoder_input_ids = inputs.input_ids.new_full(
    (inputs.input_ids.shape[0], 1), model.config.decoder_start_token_id)
logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits
preds = F.softmax(logits, dim=-1).argmax(dim=-1)  # note: argmax of the raw logits gives the same result
y = tokenizer.batch_decode(sequences=preds, skip_special_tokens=True)
You may check the original source here: Forward outputs on multiple sequences is wrong
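Since you want to build your own generate function and measure the entropy of each generated token, here is a minimal sketch of a greedy decoding loop that does both. It assumes `t5-small` and a fixed cap of 20 new tokens; the entropy of each step's next-token distribution is computed from the logits before the argmax is taken.

```python
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
model.eval()

text = "summarize: The quick brown fox jumps over the lazy dog."
enc = tokenizer(text, return_tensors="pt")

# Start the decoder from its start token (the pad token for T5).
decoder_ids = torch.tensor([[model.config.decoder_start_token_id]])

entropies = []
with torch.no_grad():
    for _ in range(20):  # cap on new tokens (an arbitrary choice here)
        logits = model(**enc, decoder_input_ids=decoder_ids).logits
        next_logits = logits[:, -1, :]  # distribution over the next token only
        probs = F.softmax(next_logits, dim=-1)
        log_probs = F.log_softmax(next_logits, dim=-1)
        # Shannon entropy of this step's distribution, in nats.
        entropies.append(-(probs * log_probs).sum(dim=-1).item())
        next_token = next_logits.argmax(dim=-1, keepdim=True)  # greedy choice
        decoder_ids = torch.cat([decoder_ids, next_token], dim=-1)
        if next_token.item() == model.config.eos_token_id:
            break

summary = tokenizer.decode(decoder_ids[0], skip_special_tokens=True)
print(summary)
print(entropies)
```

A low entropy at a given step means the model was confident about that token; a high entropy means the probability mass was spread over many candidates. If you later switch to sampling instead of argmax, the entropy computation stays exactly the same.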
Upvotes: 1