Deshwal

RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast

I am trying to implement U^2-Net for salient object detection. Since this code is not optimised for training, I followed the official PyTorch documentation for Automatic Mixed Precision (AMP) and made some changes to the original code in my fork to check the effects.

I used the code exactly as given. If you run my version of the training code on Colab with:

! git clone https://github.com/deshwalmahesh/U-2-Net
%cd ./U-2-Net/
!python u2net_train.py

it throws an error; the full stack trace is posted at the end. I dug into it and found that the error comes from the custom loss function muti_bce_loss_fusion, which the authors define as:

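import torch.nn as nn

# size_average=True is deprecated; it is equivalent to reduction='mean'
bce_loss = nn.BCELoss(size_average=True)

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    # BCE on each of the seven sigmoid outputs (d0 fused, d1..d6 side outputs)
    loss0 = bce_loss(d0,labels_v)
    loss1 = bce_loss(d1,labels_v)
    loss2 = bce_loss(d2,labels_v)
    loss3 = bce_loss(d3,labels_v)
    loss4 = bce_loss(d4,labels_v)
    loss5 = bce_loss(d5,labels_v)
    loss6 = bce_loss(d6,labels_v)

    # total loss is the unweighted sum; loss0 is returned separately as well
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    return loss0, loss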

Also, the last line of the model definition (line 526) returns seven sigmoid outputs, which are passed to the loss function:

F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)

What can be done to avoid this error?

Error trace

/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
  warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3704: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
  warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
  warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
  File "u2net_train.py", line 148, in <module>
    loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
  File "u2net_train.py", line 33, in muti_bce_loss_fusion
    loss0 = bce_loss(d0,labels_v)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 612, in forward
    return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 3065, in binary_cross_entropy
    return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss.  binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.
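The suggestion in the error message can be checked numerically. The sketch below uses random, hypothetical logits and targets shaped like a saliency map and shows that, in float32, BCEWithLogitsLoss on raw logits gives the same value (up to rounding) as BCELoss on torch.sigmoid of those logits. The difference matters under autocast: in float16, sigmoid rounds to exactly 1.0 once the logit exceeds roughly 8, so the log(1 - p) term in BCE is no longer finite (BCELoss clamps its log outputs at -100), whereas the fused version works on the logits directly.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical raw model outputs (logits) and binary targets, shaped like a saliency map
logits = torch.randn(2, 1, 8, 8)
targets = torch.randint(0, 2, (2, 1, 8, 8)).float()

# Original pattern: sigmoid inside the model, BCELoss in the loss function
loss_sigmoid_bce = nn.BCELoss()(torch.sigmoid(logits), targets)

# Fused pattern: raw logits go straight into BCEWithLogitsLoss
loss_with_logits = nn.BCEWithLogitsLoss()(logits, targets)

# In float32 the two formulations agree up to rounding error
print(loss_sigmoid_bce.item(), loss_with_logits.item())
print(torch.allclose(loss_sigmoid_bce, loss_with_logits, atol=1e-5))
```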

Answer

Deshwal

The error was caused by the numerically unstable combination of a sigmoid followed by BCE, which autocast refuses to run in reduced precision. Following the documentation and the PyTorch community, all I had to do was change the model's return values from F.sigmoid(d0), ... to the raw outputs d0, ..., and replace nn.BCELoss(size_average=True) with nn.BCEWithLogitsLoss(size_average=True). Now the model trains fine.
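A minimal sketch of this fix, assuming PyTorch 1.10 or later (for torch.autocast). TinySaliencyNet is a hypothetical stand-in for the real U2NET that, like the modified model, returns seven raw logit maps with no sigmoid. The loss uses reduction="mean", the non-deprecated equivalent of size_average=True. The step runs float16 autocast with GradScaler on a GPU and falls back to bfloat16 autocast on CPU so the sketch also runs without CUDA.

```python
import torch
import torch.nn as nn

# BCEWithLogitsLoss applies the sigmoid internally, in a numerically stable way
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    # d0..d6 are now raw logits, not sigmoid outputs
    losses = [bce_loss(d, labels_v) for d in (d0, d1, d2, d3, d4, d5, d6)]
    return losses[0], sum(losses)


class TinySaliencyNet(nn.Module):
    """Hypothetical stand-in for U2NET: returns seven logit maps, no sigmoid."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 7, kernel_size=3, padding=1)

    def forward(self, x):
        out = self.conv(x)
        # split the 7 channels into 7 single-channel maps, like d0..d6
        return tuple(out[:, i:i + 1] for i in range(7))


use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"
# float16 autocast on GPU; bfloat16 is the dtype CPU autocast supports
amp_dtype = torch.float16 if use_cuda else torch.bfloat16

model = TinySaliencyNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# loss scaling is only needed for float16; when disabled the scaler is a pass-through
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)

inputs = torch.randn(2, 3, 32, 32, device=device)
labels_v = torch.randint(0, 2, (2, 1, 32, 32), device=device).float()

optimizer.zero_grad()
with torch.autocast(device_type=device, dtype=amp_dtype):
    d0, d1, d2, d3, d4, d5, d6 = model(inputs)
    loss0, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()

# The model now returns logits, so apply the sigmoid explicitly at inference time
with torch.no_grad():
    pred = torch.sigmoid(model(inputs)[0])
```

Because the model now returns logits, any inference or evaluation code that previously consumed the sigmoid outputs has to apply torch.sigmoid itself, as the last lines do.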