dunky11

Reputation: 99

How is the transformer's loss calculated for blank token predictions?

I'm currently trying to implement a transformer and have trouble understanding its loss calculation.

For batch_size=1 and max_sentence_length=8, my encoder's input looks like:

[[Das, Wetter, ist, gut, <blank>, <blank>, <blank>, <blank>]]

My decoder's input (German to English) looks like:

[[<start>, The, weather, is, good, <end>, <blank>, <blank>]]

Let's say my transformer predicted these class probabilities (showing only the word of the highest-probability class at each position):

[[The, good, is, weather, <end>, <blank>, <blank>, <blank>]]

Now I calculate the loss using:

loss = categorical_crossentropy(
   [[The, good, is, weather, <end>, <blank>, <blank>, <blank>]],
   [[The, weather, is, good, <end>, <blank>, <blank>, <blank>]]
)

Is this the correct way to calculate the loss? My transformer always predicts the blank token as the next word, and I suspect that's because there is a mistake in my loss calculation and I need to handle the blank tokens before computing the loss.

Upvotes: 4

Views: 1955

Answers (2)

Arun prakash

Reputation: 11

When using a framework like PyTorch, you can set

ignore_index=0

while computing the cross-entropy loss with torch.nn.CrossEntropyLoss or torch.nn.functional.cross_entropy. Here I have assumed that the index of the pad token is 0.
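
For example, a minimal sketch (the vocabulary size, shapes, and token indices below are made up for illustration; ignore_index=0 is the actual point):

import torch
import torch.nn.functional as F

# Toy setup: vocabulary of 6 tokens, <pad> assumed to sit at index 0.
vocab_size = 6
logits = torch.randn(1, 8, vocab_size)             # (batch, seq_len, vocab)
target = torch.tensor([[3, 1, 4, 2, 5, 0, 0, 0]])  # trailing 0s = <pad>

# cross_entropy expects (N, C) logits and (N,) targets, so flatten first.
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    target.view(-1),
    ignore_index=0,  # targets equal to 0 contribute nothing to the loss
)

With ignore_index set, the default mean reduction also divides only by the number of non-ignored targets, so padded positions neither add loss nor dilute the average.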

Upvotes: 1

Jindřich

Reputation: 11213

You need to mask out the padding. (What you call <blank> is more often called <pad>.)

  • Create a mask saying where the valid tokens are (pseudocode: mask = target != '<pad>')

  • When computing the categorical cross-entropy, do not reduce the loss automatically; keep the per-position values.

  • Multiply the loss values with the mask, i.e., positions corresponding to the <blank> tokens get zeroed out, and sum the losses at the valid positions (pseudocode: loss_sum = (loss * mask).sum()).

  • Divide loss_sum by the number of valid positions, i.e., the sum of the mask (pseudocode: loss = loss_sum / mask.sum()). A sketch of all four steps follows after this list.
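
Putting these steps together, a minimal PyTorch sketch (the pad index, vocabulary size, and tensor shapes are made-up assumptions for illustration):

import torch
import torch.nn.functional as F

pad_index = 0  # assumption: <pad>/<blank> is mapped to index 0
vocab_size = 6
logits = torch.randn(1, 8, vocab_size)             # (batch, seq_len, vocab)
target = torch.tensor([[3, 1, 4, 2, 5, 0, 0, 0]])  # trailing 0s = <pad>

# Step 1: mask of valid (non-pad) positions.
mask = (target != pad_index).float()

# Step 2: per-position cross-entropy, no reduction.
loss = F.cross_entropy(
    logits.view(-1, vocab_size),
    target.view(-1),
    reduction="none",
).view_as(mask)

# Steps 3 and 4: zero out pad positions, then average over the valid ones.
loss = (loss * mask).sum() / mask.sum()

Dividing by mask.sum() rather than the full sequence length keeps the loss scale independent of how much padding a batch happens to contain.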

Upvotes: 5
