I have to define a Huber loss function, which is this:
This is my code:
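```python
import torch

def huber(a, b):
    diff = torch.abs(a - b)
    # quadratic branch (|a - b| < 1): 0.5 * d^2
    res = (diff[diff < 1] ** 2 / 2).sum()
    # linear branch (|a - b| >= 1): |d| - 0.5
    res += (diff[diff >= 1] - 0.5).sum()
    # mean over all elements
    return res / torch.numel(a)
```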
Yet it is not working properly. Do you have any idea what is wrong?
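For reference, the function above computes the Huber loss with threshold $\delta = 1$, averaged over the $n$ elements of the tensors. The general form is sketched below, with $\delta$ as a parameter that the code above fixes at 1 (the original formula image is not available, so this is the standard definition, not necessarily the exact one the question referred to):

$$
L_\delta(a, b) = \frac{1}{n}\sum_{i=1}^{n} \ell_\delta(a_i - b_i),
\qquad
\ell_\delta(x) =
\begin{cases}
\tfrac{1}{2}x^2 & \text{if } |x| < \delta \\[4pt]
\delta\left(|x| - \tfrac{1}{2}\delta\right) & \text{if } |x| \ge \delta
\end{cases}
$$

Setting $\delta = 1$ gives $\tfrac{1}{2}x^2$ and $|x| - \tfrac{1}{2}$, matching the two branches in the code.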
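The Huber loss already exists in PyTorch as `torch.nn.SmoothL1Loss`. With its defaults (`beta=1.0`, `reduction='mean'`), it computes 0.5·x² where |x| < 1 and |x| − 0.5 elsewhere, averaged over all elements, which is the same quantity your function computes. PyTorch 1.9 and later also provide `torch.nn.HuberLoss`, which takes a `delta` parameter.

See https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html for more details.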
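One way to check that the built-in loss agrees with a hand-written version is to compare both the loss values and the gradients on random data. The following is a minimal sketch, assuming PyTorch 1.9 or later (needed for `nn.HuberLoss`); the tensor shapes, the scaling factor, and the seed are arbitrary illustrative choices:

```python
import torch
import torch.nn as nn

def huber(a, b):
    # hand-written Huber loss with threshold 1, mean over elements
    diff = torch.abs(a - b)
    res = (diff[diff < 1] ** 2 / 2).sum()
    res += (diff[diff >= 1] - 0.5).sum()
    return res / torch.numel(a)

torch.manual_seed(0)
# scale a so that both the quadratic and the linear branch are exercised
a = (torch.randn(4, 5) * 3).requires_grad_()
b = torch.randn(4, 5)

# loss values from the custom function and the two built-ins
custom_val = huber(a, b)
smooth_val = nn.SmoothL1Loss()(a, b)         # default beta=1.0, reduction='mean'
huber_val = nn.HuberLoss(delta=1.0)(a, b)    # available since PyTorch 1.9

# gradients of each loss with respect to a
(grad_custom,) = torch.autograd.grad(custom_val, a)
(grad_smooth,) = torch.autograd.grad(smooth_val, a)

print(custom_val.item(), smooth_val.item(), huber_val.item())
print(torch.allclose(grad_custom, grad_smooth, atol=1e-6))
```

Since the values and gradients match, an unexpected training result is more likely caused by how the loss is used (for example, input shapes or `requires_grad` on the tensors) than by the formula itself.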