yanwii

Reputation: 170

How to deal with mini-batch loss in Pytorch?

I feed mini-batch data to the model, and I just want to know how to deal with the loss. Could I accumulate the loss and then call backward, like this:

    ...
    def neg_log_likelihood(self, sentences, tags, length):
        self.batch_size = sentences.size(0)

        logits = self.__get_lstm_features(sentences, length)
        real_path_score = torch.zeros(1)
        total_score = torch.zeros(1)
        if USE_GPU:
            real_path_score = real_path_score.cuda()
            total_score = total_score.cuda()

        # accumulate the scores of every sentence in the mini-batch
        for logit, tag, leng in zip(logits, tags, length):
            logit = logit[:leng]  # drop the padded time steps
            tag = tag[:leng]
            real_path_score += self.real_path_score(logit, tag)
            total_score += self.total_score(logit, tag)
        # negative log-likelihood, summed over the mini-batch
        return total_score - real_path_score
    ...
    loss = model.neg_log_likelihood(sentences, tags, length)
    loss.backward()
    optimizer.step()

I wonder whether the accumulation could lead to gradient explosion.

So, should I instead call backward inside the loop:

    for sentence, tag, leng in zip(sentences, tags, length):
        loss = model.neg_log_likelihood(sentence, tag, leng)
        loss.backward()
        optimizer.step()

Or, should I use the mean loss, just like reduce_mean in TensorFlow:

    loss = reduce_mean(losses)
    loss.backward()
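
In PyTorch I suppose that third option would amount to dividing the accumulated loss by the batch size, something like (a rough sketch, reusing the neg_log_likelihood from above):

    loss = model.neg_log_likelihood(sentences, tags, length) / sentences.size(0)
    loss.backward()
    optimizer.step()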

Upvotes: 5

Views: 7302

Answers (2)

Lerner Zhang

Reputation: 7130

We usually

  1. get the loss from the loss function
  2. (if necessary) manipulate the loss, for example by applying class weighting
  3. calculate the mean loss over the mini-batch
  4. calculate the gradients with loss.backward()
  5. (if necessary) manipulate the gradients, for example by clipping them for some RNN models to avoid gradient explosion
  6. update the weights with optimizer.step()

So in your case, you can first take the mean loss over the mini-batch, then compute the gradients with loss.backward(), and finally update the weights with optimizer.step(), as in the sketch below.
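
A minimal sketch of those steps for a generic classifier (model, optimizer, dataloader and class_weights are placeholders, not from the question):

    import torch.nn as nn
    from torch.nn.utils import clip_grad_norm_

    criterion = nn.CrossEntropyLoss(reduction='none')  # 1. per-sample losses, no reduction yet
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        logits = model(inputs)
        losses = criterion(logits, targets)             # shape: (batch_size,)
        losses = losses * class_weights[targets]        # 2. optional per-class weighting
        loss = losses.mean()                            # 3. mean over the mini-batch
        loss.backward()                                 # 4. compute gradients
        clip_grad_norm_(model.parameters(), 5.0)        # 5. optional gradient clipping
        optimizer.step()                                # 6. update the weights

With reduction='mean' (the default), steps 1 and 3 collapse into a single criterion(logits, targets) call.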

Upvotes: 1

scarecrow

Reputation: 6864

The loss has to be reduced to a mean over the mini-batch. If you look at the native PyTorch loss functions such as CrossEntropyLoss, there is a separate reduction parameter just for this, and the default behaviour is to take the mean over the mini-batch.
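
For instance (a small sketch with made-up logits and targets):

    import torch
    import torch.nn as nn

    logits = torch.randn(4, 3)            # mini-batch of 4 samples, 3 classes
    targets = torch.tensor([0, 2, 1, 0])

    mean_loss = nn.CrossEntropyLoss()(logits, targets)                   # reduction='mean' (default)
    sum_loss = nn.CrossEntropyLoss(reduction='sum')(logits, targets)
    per_sample = nn.CrossEntropyLoss(reduction='none')(logits, targets)  # shape: (4,)

    # mean_loss equals sum_loss / 4, i.e. the loss averaged over the mini-batch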

Upvotes: 2
