muhleeshe

Reputation: 61

Output logits from T5 model for text generation purposes

I am using the T5 model found on Hugging Face for text summarization. How can I output the logits of the T5 model directly given a text input for generation purposes (not training)?

I want to generate the output token by token so that I can calculate the entropy of each output token. It does not seem like the .generate() method will work for this.

I effectively want to write my own generate function, but I need the model's logits to do that.

Upvotes: 2

Views: 4628

Answers (1)

Edwin Cheong

Reputation: 989

You can use the model's forward pass to get the logits and then take the argmax, like so:

from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
import torch.nn.functional as F

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")

inputs = tokenizer(
    "test here",
    padding="longest",
    max_length=128,
    truncation=True,
    return_tensors="pt",
)

# T5 is an encoder-decoder model, so forward() also needs decoder inputs;
# here we seed the decoder with its start token (the pad token for T5)
decoder_input_ids = torch.tensor([[model.config.decoder_start_token_id]])

logits = model(**inputs, decoder_input_ids=decoder_input_ids).logits

# softmax is monotonic, so argmax over the raw logits picks the same tokens
preds = F.softmax(logits, dim=-1).argmax(dim=-1)
y = tokenizer.batch_decode(sequences=preds, skip_special_tokens=True)

You may check the original source here: Forward outputs on multiple sequences is wrong
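To get the per-token entropy the asker is after, each step's row of logits can be turned into a probability distribution and plugged into the entropy formula. A minimal helper in plain PyTorch (a sketch, independent of any particular model; the function name `token_entropies` is my own):

```python
import torch
import torch.nn.functional as F

def token_entropies(logits: torch.Tensor) -> torch.Tensor:
    """Entropy (in nats) of each position's predicted distribution.

    logits: tensor of shape (seq_len, vocab_size), one row of
    unnormalized scores per generated token.
    """
    log_probs = F.log_softmax(logits, dim=-1)  # numerically stable log-softmax
    probs = log_probs.exp()
    return -(probs * log_probs).sum(dim=-1)    # H = -sum(p * log p) per row

# Uniform logits over a 10-token vocabulary: entropy is log(10) per position
ent = token_entropies(torch.zeros(3, 10))
```

In a hand-rolled generate loop you would call this on the last position of the decoder logits (logits[:, -1, :]) at each step, then append the chosen token to decoder_input_ids and repeat.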

Upvotes: 1
