studyrivulet

Reputation: 21

Defining a loss function in PyTorch

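I have to define a Huber loss function, which is:

$$
L(a, b) = \frac{1}{n} \sum_{i=1}^{n} \ell(a_i - b_i), \qquad
\ell(x) = \begin{cases} \tfrac{1}{2}x^2 & \text{if } |x| < 1 \\ |x| - \tfrac{1}{2} & \text{if } |x| \ge 1 \end{cases}
$$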

This is my code

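import torch

def huber(a, b):
    diff = a - b
    small = torch.abs(diff) < 1                  # residuals in the quadratic region
    res = (diff[small] ** 2 / 2).sum()           # 0.5 * x^2 where |x| < 1
    res = res + (torch.abs(diff[~small]) - 0.5).sum()  # |x| - 0.5 where |x| >= 1
    res = res / torch.numel(a)                   # average over the elements of a
    return res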


Yet it is not working properly. Do you have any idea what is wrong?
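The question does not say what "not working" means, so the following is only a sanity check, not a diagnosis. It is a minimal sketch that computes the same piecewise formula with `torch.where` instead of boolean masking; the function name `huber_where` and the `delta` parameter are hypothetical additions for illustration. One thing worth ruling out: if `a` and `b` have different shapes (say `(N, 1)` and `(N,)`), `a - b` silently broadcasts to `(N, N)` while `torch.numel(a)` is still `N`, so the mean is wrong. The sketch rejects mismatched shapes for that reason.

```python
import torch

def huber_where(a: torch.Tensor, b: torch.Tensor, delta: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: mean Huber loss.

    With delta=1.0 this is 0.5*x**2 for |x| < 1 and |x| - 0.5 otherwise,
    i.e. the piecewise formula from the question.
    """
    if a.shape != b.shape:
        # Without this check, e.g. (N, 1) vs (N,) would broadcast to (N, N)
        # while dividing by numel(a) would divide by N only.
        raise ValueError(f"shape mismatch: {tuple(a.shape)} vs {tuple(b.shape)}")
    diff = torch.abs(a - b)
    # Quadratic branch for small residuals, linear branch for large ones.
    per_elem = torch.where(diff < delta,
                           0.5 * diff ** 2,
                           delta * (diff - 0.5 * delta))
    # .mean() divides by the number of elements, like torch.numel(a) in the question.
    return per_elem.mean()

pred = torch.tensor([0.5, 2.0, -1.5, 0.0], requires_grad=True)
target = torch.zeros(4)
loss = huber_where(pred, target)
loss.backward()
print(loss.item())           # 0.65625
print(pred.grad.tolist())    # [0.125, 0.25, -0.25, 0.0]
```

`torch.where` evaluates both branches and selects per element, which keeps the output shape equal to the input shape; the masked version in the question produces the same value for same-shaped inputs.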

Upvotes: 0

Views: 2090

Answers (1)

Dimitri Sifoua

Reputation: 499

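The Huber loss already exists in PyTorch as torch.nn.SmoothL1Loss. With its default beta=1.0 it computes 0.5 * x**2 where |x| < 1 and |x| - 0.5 otherwise, averaged over all elements by default, which is exactly the function in the question.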

Follow this link https://pytorch.org/docs/stable/generated/torch.nn.SmoothL1Loss.html for more!
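A minimal usage sketch, assuming PyTorch 1.9 or later (the `beta` argument of `SmoothL1Loss` needs 1.7+, and `nn.HuberLoss` was added in 1.9). The input tensors are made-up illustration values. It checks that the built-in losses agree with the piecewise formula from the question.

```python
import torch
import torch.nn as nn

pred = torch.tensor([0.5, 2.0, -1.5, 0.0])
target = torch.zeros(4)

# beta=1.0 is the default; spelled out here to tie it to the |x| < 1 threshold.
smooth_l1 = nn.SmoothL1Loss(reduction="mean", beta=1.0)
print(smooth_l1(pred, target).item())  # 0.65625

# The question's piecewise definition, written out by hand for comparison.
diff = torch.abs(pred - target)
manual = torch.where(diff < 1, 0.5 * diff ** 2, diff - 0.5).mean()
print(manual.item())  # 0.65625

# nn.HuberLoss with delta=1.0 gives the same values as SmoothL1Loss(beta=1.0).
huber = nn.HuberLoss(reduction="mean", delta=1.0)
print(huber(pred, target).item())  # 0.65625
```

For thresholds other than 1 the two built-ins differ by a constant factor: `SmoothL1Loss` divides the quadratic part by `beta`, while `HuberLoss` multiplies the linear part by `delta`.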

Upvotes: 1
