programmingWolf

Reputation: 43

Pytorch: how do I make sure the model output is not 0 or negative?

I need to calculate the loss for my model. The loss requires the logarithm of the model's output (this is for an actor-critic model, for those who want to know). My network uses ReLU and softmax to make sure the values don't get too high or go negative, but the output is sometimes 0. That is a problem, since I cannot take the log of 0.

What can I do to avoid this?

I tried using a custom ReLU function, but for some reason it does not work.

I also tried bumping the value up by 0.01 in the cases where it is 0, but then I get an error about an in-place change.
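For illustration, here is a minimal sketch (assuming the 0.01 bump was done by indexing into the output in place) of why that kind of error shows up, and an out-of-place alternative:

import torch

x = torch.randn(4, requires_grad=True)
P = torch.softmax(x, dim=0)   # softmax saves its output for the backward pass

# Bumping small values in place modifies a tensor autograd still needs,
# which typically raises the "modified by an inplace operation" error:
# P[P < 0.01] = 0.01

# Out-of-place alternative: clamp returns a new tensor, so backward works
P_safe = P.clamp(min=0.01)
torch.log(P_safe).sum().backward()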

The loss function looks like this, where P is the output of the model, eta and value_constant are some unimportant constants, and a[t] is the action at time t. Those details don't matter; the important part is that the output P must not be 0.0.

x = self.eta * P * torch.log(P)
theta_loss += -value_constant * torch.log(P[a[t]]) + torch.sum(x)
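For illustration, with a toy P that contains an exact 0 (assuming P is a probability vector, e.g. from softmax), both terms of that loss break:

P = torch.tensor([0.0, 0.5, 0.5])
print(torch.log(P))      # tensor([  -inf, -0.6931, -0.6931])
print(P * torch.log(P))  # tensor([    nan, -0.3466, -0.3466])  -- 0 * -inf gives nan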

This is the ReLU function:

class MyReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, inp):
        ctx.save_for_backward(inp)
        # clamp everything below the threshold up to 0.01 so the output is never 0
        return torch.where(inp < 0.01, torch.full_like(inp, 0.01), inp)

    @staticmethod
    def backward(ctx, grad_output):
        inp, = ctx.saved_tensors
        # gradient is 1 where the input was above the threshold, 0 elsewhere
        grad = (inp > 0.01).to(grad_output.dtype)
        return grad_output * grad
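For reference, a custom autograd.Function is used through its .apply method rather than being instantiated like a module. A minimal usage sketch of the class above:

x = torch.tensor([-0.5, 0.0, 0.2], requires_grad=True)
y = MyReLU.apply(x)      # note: MyReLU.apply(x), not MyReLU()(x)
y.sum().backward()
print(y)                 # tensor([0.0100, 0.0100, 0.2000], grad_fn=...)
print(x.grad)            # tensor([0., 0., 1.])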

Upvotes: 1

Views: 65

Answers (1)

MuhammedYunus

Reputation: 5105

The approach below uses nn.Softplus and .clamp to map an arbitrarily-valued tensor to a positive one, before taking the log.

P_nonneg = nn.Softplus()(P_original)     # constrain it to be >= 0
P_pos = torch.clamp(P_nonneg, min=1e-10) # prevent 0, so it's always +ve
loss = torch.log(P_pos)                  # safely take log

You can set min= depending on how far away from 0 you want to constrain P immediately before taking the log. I've set it to a small number, so it can come close to 0 and potentially render a huge loss, but it'll never hit 0 exactly. If such a large loss causes issues, you can turn min= up to smooth things out, though it makes the loss a bit less sensitive to smaller values.
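A self-contained sketch of the same idea using the functional form (torch.nn.functional.softplus is the same operation as the nn.Softplus layer); the softmax output here just stands in for the model's P:

import torch
import torch.nn.functional as F

logits = torch.randn(5, requires_grad=True)
P_original = torch.softmax(logits, dim=0)               # in the question's setting this can hit exact 0

P_pos = torch.clamp(F.softplus(P_original), min=1e-10)  # always strictly positive
loss = -torch.log(P_pos).sum()
loss.backward()                                          # no inf/nan, gradients flow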

Upvotes: 0
