I am trying to implement U^2-Net for salient object detection. Since the original code is not optimised for training performance, I followed the official PyTorch documentation for Automatic Mixed Precision (AMP) and made some changes to the original code in my fork to check the effects.
I used the code from the documentation exactly. To reproduce, run my version of the training code on Colab:
! git clone https://github.com/deshwalmahesh/U-2-Net
%cd ./U-2-Net/
!python u2net_train.py
This throws an error; the full stack trace is posted at the end. I dug into it and found that the error comes from the custom loss function muti_bce_loss_fusion, which the authors define as:
import torch.nn as nn

# size_average=True is deprecated; it is equivalent to reduction='mean'
bce_loss = nn.BCELoss(size_average=True)

def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    # One BCE term per side output d0..d6, all against the same label map
    loss0 = bce_loss(d0, labels_v)
    loss1 = bce_loss(d1, labels_v)
    loss2 = bce_loss(d2, labels_v)
    loss3 = bce_loss(d3, labels_v)
    loss4 = bce_loss(d4, labels_v)
    loss5 = bce_loss(d5, labels_v)
    loss6 = bce_loss(d6, labels_v)
    loss = loss0 + loss1 + loss2 + loss3 + loss4 + loss5 + loss6
    return loss0, loss
Also, in the last line of the model definition (line 526), the model returns seven sigmoid outputs, which are then passed to the loss function:
return F.sigmoid(d0), F.sigmoid(d1), F.sigmoid(d2), F.sigmoid(d3), F.sigmoid(d4), F.sigmoid(d5), F.sigmoid(d6)
What can I do to avoid this error?
Error trace
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:780: UserWarning: Note that order of the arguments: ceil_mode and return_indices will changeto match the args list in nn.MaxPool2d in a future release.
warnings.warn("Note that order of the arguments: ceil_mode and return_indices will change"
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:3704: UserWarning: nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.
warnings.warn("nn.functional.upsample is deprecated. Use nn.functional.interpolate instead.")
/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py:1944: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.
warnings.warn("nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.")
Traceback (most recent call last):
File "u2net_train.py", line 148, in <module>
loss2, loss = muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v)
File "u2net_train.py", line 33, in muti_bce_loss_fusion
loss0 = bce_loss(d0,labels_v)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1110, in _call_impl
return forward_call(*input, **kwargs)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/loss.py", line 612, in forward
return F.binary_cross_entropy(input, target, weight=self.weight, reduction=self.reduction)
File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 3065, in binary_cross_entropy
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
RuntimeError: torch.nn.functional.binary_cross_entropy and torch.nn.BCELoss are unsafe to autocast.
Many models use a sigmoid layer right before the binary cross entropy layer.
In this case, combine the two layers using torch.nn.functional.binary_cross_entropy_with_logits
or torch.nn.BCEWithLogitsLoss. binary_cross_entropy_with_logits and BCEWithLogits are
safe to autocast.
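One way to see why the error message calls a separate sigmoid followed by BCE unsafe is saturation. The small float32 example below uses an arbitrary, very confident wrong prediction (logit 30, target 0): `sigmoid(30.)` rounds to exactly 1.0, so `log(1 - p)` is `log(0)`, and `BCELoss` falls back to its documented clamp of the log at -100. The fused `BCEWithLogitsLoss` works on the logit directly and returns the correct value. float16 has far less precision than float32, so the same collapse happens at much smaller logits. This illustrates the general precision problem; autocast itself simply refuses to run `binary_cross_entropy` rather than checking values.

```python
import torch
import torch.nn as nn

# A single, very confident (and wrong) prediction: large positive logit, target 0
logit = torch.tensor([30.0])
target = torch.tensor([0.0])

# Two-step version: sigmoid, then BCE. sigmoid(30.) rounds to exactly 1.0 in float32,
# so log(1 - p) = log(0) and BCELoss clamps the log at -100, giving a loss of 100.
p = torch.sigmoid(logit)
two_step = nn.BCELoss()(p, target)

# Fused version computes the loss from the logit in a numerically stable form.
fused = nn.BCEWithLogitsLoss()(logit, target)

print(p.item())         # 1.0
print(two_step.item())  # 100.0
print(fused.item())     # 30.0, the true value of -log(1 - sigmoid(30)) to float32 precision
```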
Answer
The error comes from the numerically unstable combination of a separate sigmoid followed by BCE, which autocast does not allow. Following the documentation and the PyTorch community forums, all I had to do was change the model's return values from F.sigmoid(d0), ... to the raw outputs d0, ..., and then replace nn.BCELoss(size_average=True) with nn.BCEWithLogitsLoss(size_average=True). Now the model trains fine.
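A minimal sketch of the fixed loss function, assuming the model now returns the raw logits d0 ... d6. It uses `reduction="mean"`, the non-deprecated spelling of `size_average=True`, and then checks on random float32 data (arbitrary shapes, not real U-2-Net outputs) that it gives the same values as the original sigmoid + `BCELoss` version.

```python
import torch
import torch.nn as nn

# reduction="mean" is the current equivalent of the deprecated size_average=True
bce_logits_loss = nn.BCEWithLogitsLoss(reduction="mean")


def muti_bce_loss_fusion(d0, d1, d2, d3, d4, d5, d6, labels_v):
    # d0..d6 are raw logits: the model no longer applies sigmoid to its outputs
    losses = [bce_logits_loss(d, labels_v) for d in (d0, d1, d2, d3, d4, d5, d6)]
    loss0 = losses[0]
    loss = torch.stack(losses).sum()
    return loss0, loss


# Sanity check against the original sigmoid + BCELoss formulation (float32, random data)
torch.manual_seed(0)
labels_v = torch.randint(0, 2, (2, 1, 32, 32)).float()
logits = [torch.randn(2, 1, 32, 32) for _ in range(7)]

old_bce = nn.BCELoss(reduction="mean")
old_losses = [old_bce(torch.sigmoid(d), labels_v) for d in logits]
old_loss0, old_loss = old_losses[0], torch.stack(old_losses).sum()

new_loss0, new_loss = muti_bce_loss_fusion(*logits, labels_v)
print(torch.allclose(old_loss0, new_loss0), torch.allclose(old_loss, new_loss))  # True True
```

Because the model outputs are now logits, any code that treats them as probabilities (for example, when saving predicted saliency maps) has to apply `torch.sigmoid` itself.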