Reputation: 9869
I'm trying to figure out how sequence to sequence loss is calculated. I am using the huggingface transformers library in this case, but this might actually be relevant to other DL libraries.
So to get the required data we can do:
from transformers import EncoderDecoderModel, BertTokenizer
import torch
import torch.nn.functional as F
torch.manual_seed(42)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LEN = 128
tokenize = lambda x: tokenizer(x, max_length=MAX_LEN, truncation=True, padding=True, return_tensors="pt")
model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
input_seq = ["Hello, my dog is cute", "my cat cute"]
output_seq = ["Yes it is", "ok"]
input_tokens = tokenize(input_seq)
output_tokens = tokenize(output_seq)
outputs = model(
input_ids=input_tokens["input_ids"],
attention_mask=input_tokens["attention_mask"],
decoder_input_ids=output_tokens["input_ids"],
decoder_attention_mask=output_tokens["attention_mask"],
labels=output_tokens["input_ids"],
return_dict=True)
idx = output_tokens["input_ids"]
logits = F.log_softmax(outputs["logits"], dim=-1)
mask = output_tokens["attention_mask"]
Thanks to @cronoik I was able to replicate the loss calculated by huggingface as being:
output_logits = logits[:,:-1,:]
output_mask = mask[:,:-1]
label_tokens = output_tokens["input_ids"][:, 1:].unsqueeze(-1)
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze(-1)  # drop only the gathered dim
huggingface_loss = -select_logits.mean()
However, since the last two tokens of the second output sequence are just padding, shouldn't we calculate the loss as:
seq_loss = (select_logits * output_mask).sum(dim=-1, keepdims=True) / output_mask.sum(dim=-1, keepdims=True)
seq_loss = -seq_loss.mean()
^This takes into account the length of each output row and masks out the padding. I think this is especially useful when a batch contains outputs of varying length.
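To make the difference between the two averaging schemes concrete, here is a minimal sketch on hand-written numbers. The log-probabilities and the mask are made up for illustration and do not come from the model above.

```python
import torch

# Hypothetical log-probabilities of the correct tokens (batch of 2, length 3).
# The last two entries of the second row sit on padding positions.
select_logits = torch.tensor([[-1.0, -1.0, -1.0],
                              [-4.0, -9.0, -9.0]])
output_mask = torch.tensor([[1, 1, 1],
                            [1, 0, 0]])

# Token-level average: every real token in the batch gets equal weight.
token_loss = -(select_logits * output_mask).sum() / output_mask.sum()

# Sequence-level average: average within each row first, then across rows.
per_seq = (select_logits * output_mask).sum(dim=-1) / output_mask.sum(dim=-1)
seq_loss = -per_seq.mean()

print(token_loss.item(), seq_loss.item())  # 1.75 and 2.5
```

With the token-level average the short sequence contributes one token out of four, while the sequence-level average gives each row the same weight regardless of its length.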
Upvotes: 6
Views: 1869
Reputation: 11
Thanks for sharing. However, as of today, the new version of transformers no longer "shifts", so the following is not needed:
#shift things
output_logits = logits[:,:-1,:]
label_tokens = idx[:, 1:].unsqueeze(-1)
output_mask = mask[:,1:]
Upvotes: 1
Reputation: 9869
OK, I found out where I was making the mistakes. This is all thanks to this thread in the HuggingFace forum.
The output labels need to be -100 at the masked positions. The transformers library does not do it for you.
The mask should also have been output_mask = mask[:, 1:] instead of :-1.
We need to set the masked positions of the output labels to -100. It is important to use clone as shown below:
labels = output_tokens["input_ids"].clone()
labels[output_tokens["attention_mask"]==0] = -100
outputs = model(
input_ids=input_tokens["input_ids"],
attention_mask=input_tokens["attention_mask"],
decoder_input_ids=output_tokens["input_ids"],
decoder_attention_mask=output_tokens["attention_mask"],
labels=labels,
return_dict=True)
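As a small sanity check of what the -100 labels do, here is a sketch that uses only torch and no model. The token ids, the pad id of 0, and the vocabulary size are hypothetical; the only assumption about the library is that the loss is a standard cross-entropy, whose default ignore_index in PyTorch is -100.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10  # hypothetical toy vocabulary
input_ids = torch.tensor([[2, 5, 7], [3, 0, 0]])  # 0 plays the role of the pad id
attention_mask = torch.tensor([[1, 1, 1], [1, 0, 0]])

labels = input_ids.clone()          # clone so input_ids is not modified in place
labels[attention_mask == 0] = -100  # default ignore_index of F.cross_entropy

logits = torch.randn(2, 3, vocab_size)
loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))

# Changing the logits at the padded positions leaves the loss unchanged.
perturbed = logits.clone()
perturbed[attention_mask == 0] += 5.0 * torch.randn(vocab_size)
loss_perturbed = F.cross_entropy(perturbed.view(-1, vocab_size), labels.view(-1))

print(loss.item(), loss_perturbed.item())
```

Without the clone, the in-place assignment would also write -100 into input_ids, which are still needed as decoder_input_ids.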
So the final way to replicate it is as follows:
idx = output_tokens["input_ids"]
logits = F.log_softmax(outputs["logits"], dim=-1)
mask = output_tokens["attention_mask"]
# shift things
output_logits = logits[:,:-1,:]
label_tokens = idx[:, 1:].unsqueeze(-1)
output_mask = mask[:,1:]
# gather the logits and mask
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze(-1)  # drop only the gathered dim
-select_logits[output_mask==1].mean(), outputs["loss"]
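The same comparison can be tried without downloading a model. The sketch below uses random logits as a stand-in for outputs["logits"] and made-up token ids, and assumes the loss is computed on shifted logits and labels as described in this answer. Under that assumption, the shift-and-gather route should match a plain cross-entropy with -100 at the padding positions.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab_size = 10  # hypothetical toy vocabulary
idx = torch.tensor([[1, 4, 6, 2], [1, 5, 2, 0]])   # 0 stands in for the pad id
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
raw_logits = torch.randn(2, 4, vocab_size)         # stand-in for outputs["logits"]

# Manual route: position t predicts token t + 1.
log_probs = F.log_softmax(raw_logits, dim=-1)[:, :-1, :]
label_tokens = idx[:, 1:].unsqueeze(-1)
output_mask = mask[:, 1:]
select = torch.gather(log_probs, -1, label_tokens).squeeze(-1)
manual_loss = -select[output_mask == 1].mean()

# Reference route: cross-entropy on the same shifted tensors, -100 at padding.
labels = idx.clone()
labels[mask == 0] = -100
reference_loss = F.cross_entropy(
    raw_logits[:, :-1, :].reshape(-1, vocab_size),
    labels[:, 1:].reshape(-1),
)

print(manual_loss.item(), reference_loss.item())
```

Note that reshape is used instead of view because the sliced tensors are no longer contiguous.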
The above, however, ignores the fact that these tokens come from two different sequences (rows of the batch). So an alternative way of calculating the loss could be:
seq_loss = (select_logits * output_mask).sum(dim=-1, keepdims=True) / output_mask.sum(dim=-1, keepdims=True)
-seq_loss.mean()  # negate: select_logits holds log-probabilities
Upvotes: 1