J.Tmr

Reputation: 137

Why does my pytorch NN return a tensor of nan?

I have a fairly simple neural network that takes a flattened 6x6 grid as input and should output the values of four actions to take on that grid, i.e. a 1x4 tensor of values.

Sometimes, after a few runs, I am for some reason getting a 1x4 tensor of nan:

tensor([[nan, nan, nan, nan]], grad_fn=<ReluBackward0>)

My model looks like this with input dim being 36 and output dim being 4:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    def __init__(self, input_dim, output_dim):
        # super refers to nn.Module, so this initializes nn.Module
        super(Model, self).__init__()
        # Grid size as input;
        # the last layer needs 4 outputs because of 4 possible actions: left, right, up, down.
        # The outputs are Q-values; the action is chosen from them, e.g. via argmax.
        self.lin1 = nn.Linear(input_dim, 24)
        self.lin2 = nn.Linear(24, 24)
        self.lin3 = nn.Linear(24, output_dim)

    # feed the input through the net
    def forward(self, x):
        # accept numpy arrays as well as tensors
        if isinstance(x, np.ndarray):
            x = torch.tensor(x, dtype=torch.float)

        # rectified linear as activation function for all three layers
        activation1 = F.relu(self.lin1(x))
        activation2 = F.relu(self.lin2(activation1))
        output = F.relu(self.lin3(activation2))

        return output

The input was:

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 1.0000, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333,
         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000,
         0.0000, 0.0000, 0.3333, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.6667]])

What are possible causes for a nan output, and how can I fix them?

Upvotes: 4

Views: 20501

Answers (1)

Chris Holland

Reputation: 579

nan values as outputs just mean that the training is unstable, which can have almost any cause, including all kinds of bugs in the code. If you think your code is correct, you can try addressing the instability by lowering the learning rate or using gradient clipping.
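A minimal sketch of both mitigations, using a stand-in model with the same 36-in/4-out shape as in the question (the random inputs, the Adam optimizer, and the MSE loss here are just illustrative assumptions, not the asker's actual training setup):

```python
import torch
import torch.nn as nn

# Toy stand-in for the Q-network: same 36 -> 24 -> 24 -> 4 shape
model = nn.Sequential(
    nn.Linear(36, 24), nn.ReLU(),
    nn.Linear(24, 24), nn.ReLU(),
    nn.Linear(24, 4),
)

# Mitigation 1: a smaller learning rate (1e-4 instead of a typical 1e-2/1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Dummy batch and target, only to produce a gradient
x = torch.rand(8, 36)
target = torch.rand(8, 4)
loss = nn.functional.mse_loss(model(x), target)

optimizer.zero_grad()
loss.backward()

# Mitigation 2: clip the global gradient norm before the optimizer step,
# so a single exploding gradient cannot blow up the weights
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()

# Sanity check: no parameter became nan after the update
assert all(torch.isfinite(p).all() for p in model.parameters())
```

To find where the nan first appears, `torch.autograd.set_detect_anomaly(True)` can also help: it makes the backward pass raise an error at the operation that produced the nan.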

Upvotes: 1
