Reputation: 43
I need to calculate the loss for my model, and the loss requires the logarithm of the model's output. This is for an actor-critic model, for those who want to know. The network uses ReLU and softmax to keep the values from getting too high or going negative, but the output is sometimes exactly 0, which is a problem because I cannot take the log of that.
What can I do to avoid this?
I tried using a custom ReLU function, but for some reason it does not work.
I also tried increasing the value by 0.01 whenever it is 0, but then I get an error about an in-place change.
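To make the error concrete, here is a small self-contained reproduction of the kind of in-place fix I mean (hypothetical values, not my actual training code):
import torch

logits = torch.tensor([100.0, -100.0, 0.0], requires_grad=True)
P = torch.softmax(logits, dim=0)  # P[1] underflows to exactly 0
P[P == 0] = 0.01                  # in-place change to a tensor autograd still needs
loss = -torch.log(P[1])
loss.backward()                   # fails: a tensor needed for backward was modified in-place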
The loss function looks like this, where P is the output of the model, eta and value_constant are constants that don't matter here, and a[t] is the action at time t. The important part is that the output P must not be 0.0.
x = self.eta*P*torch.log(P)  # P * log(P) is NaN wherever P is exactly 0
theta_loss += -value_constant*torch.log(P[a[t]])+torch.sum(x)  # log(P[a[t]]) is -inf if that probability is 0
This is the ReLU function:
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        # out = torch.zeros_like(inp).cuda()
        # out[inp > 0.01] = inp
        # clamp everything below 0.01 up to 0.01 so the output can never be 0
        return torch.where(inp < 0.01, torch.full_like(inp, 0.01), inp)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        # grad_input = grad_output.clone()
        # grad_input[inp < 0.01] = 0
        # gradient is 1 where the input passed through, 0 where it was clamped
        grad = torch.where(inp <= 0.01, torch.zeros_like(inp), torch.ones_like(inp))
        return grad_output * grad
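For reference, the function is meant to be called through its .apply attribute; a minimal standalone check (made-up values) behaves like this:
relu = MyReLU.apply
x = torch.tensor([-1.0, 0.0, 0.5], requires_grad=True)
y = relu(x)            # tensor([0.0100, 0.0100, 0.5000])
y.sum().backward()
print(x.grad)          # tensor([0., 0., 1.])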
Upvotes: 1
Views: 65
Reputation: 5105
The approach below uses nn.Softplus and .clamp to map an arbitrarily-valued tensor to a positive one before taking the log.
P_nonneg = nn.Softplus()(P_original)      # constrain it to be >= 0 (Softplus is a module, so instantiate it first)
P_pos = torch.clamp(P_nonneg, min=1e-10)  # prevent 0, so it's always +ve
loss = torch.log(P_pos)                   # safely take the log
You can set min= depending on how far away from 0 you want to constrain P immediately before taking the log. I've set it to a small number, so it can come close to 0 and potentially produce a huge loss, but it will never hit 0 exactly. If such a large loss causes issues, you can turn min= up to smooth things out, though that makes the loss a bit less sensitive to small values.
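Applied to the loss from the question, a minimal sketch could look like this (assuming P, eta, value_constant, a[t], and theta_loss are the question's variables, with P coming out of a softmax; since a softmax output is already non-negative, the clamp alone may be enough):
import torch
import torch.nn.functional as F

eps = 1e-10
# P already comes from a softmax, so it is non-negative; clamping is enough here.
P_safe = torch.clamp(P, min=eps)
# For an arbitrarily-valued tensor, push it positive first:
# P_safe = torch.clamp(F.softplus(P), min=eps)

x = eta * P_safe * torch.log(P_safe)  # entropy term, now safe to log
theta_loss += -value_constant * torch.log(P_safe[a[t]]) + torch.sum(x)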
Upvotes: 0