sachinruk

Reputation: 9869

Sequence to Sequence Loss

I'm trying to figure out how sequence-to-sequence loss is calculated. I am using the huggingface transformers library in this case, but this might be relevant to other DL libraries as well.

So to get the required data we can do:

from transformers import EncoderDecoderModel, BertTokenizer
import torch
import torch.nn.functional as F
torch.manual_seed(42)

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
MAX_LEN = 128
tokenize = lambda x: tokenizer(x, max_length=MAX_LEN, truncation=True, padding=True, return_tensors="pt")

model = EncoderDecoderModel.from_encoder_decoder_pretrained('bert-base-uncased', 'bert-base-uncased') # initialize Bert2Bert from pre-trained checkpoints
input_seq = ["Hello, my dog is cute", "my cat cute"]
output_seq = ["Yes it is", "ok"]
input_tokens = tokenize(input_seq)
output_tokens = tokenize(output_seq)

outputs = model(
    input_ids=input_tokens["input_ids"], 
    attention_mask=input_tokens["attention_mask"],
    decoder_input_ids=output_tokens["input_ids"], 
    decoder_attention_mask=output_tokens["attention_mask"],
    labels=output_tokens["input_ids"], 
    return_dict=True)

idx = output_tokens["input_ids"]                    # (batch, seq_len) target token ids
logits = F.log_softmax(outputs["logits"], dim=-1)   # (batch, seq_len, vocab) log-probabilities
mask = output_tokens["attention_mask"]              # (batch, seq_len): 1 for real tokens, 0 for padding

Edit 1

Thanks to @cronoik I was able to replicate the loss calculated by huggingface as:

output_logits = logits[:,:-1,:]
output_mask = mask[:,:-1]
label_tokens = output_tokens["input_ids"][:, 1:].unsqueeze(-1)
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze()
huggingface_loss = -select_logits.mean()
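(A quick sanity check, my addition rather than part of the original snippet: print the hand-computed value next to the loss the model reports.)

# compare the hand-computed value with the loss returned by the model
print(huggingface_loss, outputs["loss"])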

However, since the last two tokens of the second output sequence are just padding, shouldn't we calculate the loss as:

seq_loss = (select_logits * output_mask).sum(dim=-1, keepdims=True) / output_mask.sum(dim=-1, keepdims=True)
seq_loss = -seq_loss.mean()

This takes into account the length of each output sequence in the batch and masks out the padding. I think this is especially useful when we have batches of varying-length outputs.
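To make the difference concrete, here is a small self-contained sketch (toy numbers of my own, not taken from the example above) comparing the flat token-level mean with the per-sequence mean proposed here:

import torch

# two "rows" of gathered log-probabilities; the second is padded to length 4
select_logits = torch.tensor([[-0.5, -1.0, -0.2, -0.8],
                              [-2.0, -0.1,  0.0,  0.0]])
output_mask = torch.tensor([[1., 1., 1., 1.],
                            [1., 1., 0., 0.]])

# token-level mean over all non-padded positions
token_loss = -(select_logits * output_mask).sum() / output_mask.sum()

# per-sequence mean first, then mean over the batch
seq_loss = (select_logits * output_mask).sum(dim=-1) / output_mask.sum(dim=-1)
seq_loss = -seq_loss.mean()

print(token_loss, seq_loss)  # ~0.7667 vs ~0.8375: the shorter sequence gets more weight per token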

Upvotes: 6

Views: 1869

Answers (2)

XiaoxiaShirley Wu

Reputation: 11

Thanks for sharing. However, the current version of transformers actually does not do the "shift" anymore, so the following is not needed:

# shift things
output_logits = logits[:,:-1,:]
label_tokens = idx[:, 1:].unsqueeze(-1)
output_mask = mask[:, 1:]
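If that is the case (I have not checked in which transformers version this changed), replicating the loss would gather the log-probabilities at the label positions directly, roughly like this:

# hedged sketch, assuming no internal shift is applied by the library:
# gather log-probabilities at the label positions without offsetting by one
select_logits = torch.gather(logits, -1, idx.unsqueeze(-1)).squeeze(-1)  # (batch, seq_len)
unshifted_loss = -select_logits[mask == 1].mean()
print(unshifted_loss, outputs["loss"])  # only expected to agree if the installed version really does not shift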

Upvotes: 1

sachinruk

Reputation: 9869

OK, I found out where I was making the mistakes, all thanks to this thread in the HuggingFace forum.

  1. The output labels need to have -100 at the masked (padded) positions. The transformers library does not do this for you.
  2. One silly mistake I made was with the mask: it should have been output_mask = mask[:, 1:] instead of mask[:, :-1].

1. Using the Model

We need to set the padded positions of the output labels to -100. It is important to use clone as shown below so that the original output_tokens are not modified:

labels = output_tokens["input_ids"].clone()
labels[output_tokens["attention_mask"]==0] = -100

outputs = model(
    input_ids=input_tokens["input_ids"], 
    attention_mask=input_tokens["attention_mask"],
    decoder_input_ids=output_tokens["input_ids"], 
    decoder_attention_mask=output_tokens["attention_mask"],
    labels=labels, 
    return_dict=True)
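As a side note (my addition, not from the forum thread): -100 works because it is the default ignore_index of PyTorch's cross-entropy loss. Assuming the library applies the same shift as below, a rough equivalent check on the raw logits could look like this:

# F.cross_entropy ignores target positions equal to -100 by default (ignore_index=-100),
# which is why the padded label positions drop out of the average
shift_logits = outputs["logits"][:, :-1, :]   # raw (unnormalised) logits
shift_labels = labels[:, 1:]                  # padded positions already set to -100
ce_loss = F.cross_entropy(
    shift_logits.reshape(-1, shift_logits.size(-1)),  # (batch * seq, vocab)
    shift_labels.reshape(-1),                         # (batch * seq,)
)
print(ce_loss, outputs["loss"])  # should be close if the shifting assumption holds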

2. Calculating Loss

So the final way to replicate it is as follows:

idx = output_tokens["input_ids"]
logits = F.log_softmax(outputs["logits"], dim=-1)
mask = output_tokens["attention_mask"]

# shift things
output_logits = logits[:,:-1,:]
label_tokens = idx[:, 1:].unsqueeze(-1)
output_mask = mask[:,1:]

# gather the logits and mask
select_logits = torch.gather(output_logits, -1, label_tokens).squeeze()
# the masked mean of the negative log-probabilities should match the model's reported loss
print(-select_logits[output_mask == 1].mean(), outputs["loss"])

The above, however, ignores the fact that the tokens come from two different sequences. So an alternative way of calculating the loss could be:

seq_loss = (select_logits * output_mask).sum(dim=-1, keepdims=True) / output_mask.sum(dim=-1, keepdims=True)
seq_loss = -seq_loss.mean()

Upvotes: 1
