Li haonan

Reputation: 628

PyTorch: CNN doesn't learn anything after torch.cat()?

I'm trying to concatenate a Variable inside the network's forward pass with code like this:

    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)

    x = x.view(x.size(0), -1)
    x = torch.cat((x, angle), 1)  # concatenate angle here
    x = self.dropout1(self.relu1(self.bn1(self.fc1(x))))
    x = self.dropout2(self.relu2(self.bn2(self.fc2(x))))
    x = self.fc3(x)
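For comparison, here is a minimal self-contained sketch of the same pattern. It is not the asker's model: the conv stack, layer sizes, and the assumption that `angle` holds one float per sample are all hypothetical. It shows the two things the concatenation needs: `angle` reshaped to `(N, 1)` with the same dtype and device as the flattened features, and `fc1` widened by one input unit to accept the extra column.

```python
import torch
import torch.nn as nn


class ConcatNet(nn.Module):
    """Hypothetical CNN that appends one scalar feature (angle) to the
    flattened conv features before the fully connected head."""

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # Small stand-in for layer1..layer4 (assumed architecture).
        self.features = nn.Sequential(
            nn.Conv2d(1, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),  # output shape: (N, 8, 4, 4)
        )
        flat_dim = 8 * 4 * 4
        # +1 input unit for the concatenated angle column.
        self.fc1 = nn.Linear(flat_dim + 1, 32)
        self.bn1 = nn.BatchNorm1d(32)
        self.fc2 = nn.Linear(32, num_classes)

    def forward(self, x: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
        x = self.features(x)
        x = x.view(x.size(0), -1)  # (N, flat_dim)
        # Force angle to shape (N, 1) with the same dtype/device as x.
        angle = angle.view(-1, 1).to(device=x.device, dtype=x.dtype)
        x = torch.cat((x, angle), dim=1)  # (N, flat_dim + 1)
        x = torch.relu(self.bn1(self.fc1(x)))
        return self.fc2(x)


model = ConcatNet()
images = torch.randn(4, 1, 16, 16)
angles = torch.rand(4)  # assumed: one angle per sample, already scaled
logits = model(images, angles)
print(logits.shape)  # torch.Size([4, 2])
```

If this version trains while the real one does not, the concatenation itself is unlikely to be the cause, and the values fed in as `angle` become the next suspect.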

The network then learns nothing: accuracy always stays around 50%. I printed param.grad and, as I suspected, the gradients are all NaN. Has anyone run into this before?
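One way to narrow down where the NaNs appear is to list which parameters have non-finite gradients after `backward()`, instead of printing every `param.grad`. The helper name `nan_grad_params` is hypothetical; the demo feeds a single NaN into a linear layer to show that it poisons the weight gradient but not the bias gradient.

```python
import torch
import torch.nn as nn


def nan_grad_params(model: nn.Module) -> list[str]:
    """Return the names of parameters whose .grad contains NaN or inf."""
    bad = []
    for name, param in model.named_parameters():
        if param.grad is not None and not torch.isfinite(param.grad).all():
            bad.append(name)
    return bad


# Demo: one NaN in the input makes weight.grad non-finite;
# bias.grad depends only on the upstream gradient, so it stays finite.
layer = nn.Linear(3, 1)
x = torch.tensor([[1.0, float("nan"), 2.0]])
layer(x).sum().backward()
print(nan_grad_params(layer))  # ['weight']
```

For a full model, wrapping the forward and backward pass in `torch.autograd.set_detect_anomaly(True)` also makes autograd raise an error pointing at the operation whose backward first produced NaN, at the cost of slower training.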

I ran the code without the concatenation before and it worked well, so I suspect this is where the problem lies, even though no error or exception is raised. If any other information is needed, please let me know.

Thank you.

Upvotes: 1

Views: 447

Answers (1)

Egor Lakomkin

Reputation: 1434

The error is probably somewhere outside the code you provided. Check whether your input contains NaNs, and whether the loss function is producing NaN.
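A minimal sketch of those two checks, assuming the batch consists of an image tensor plus an `angle` tensor as in the question. The helper `find_non_finite` and the sample values are hypothetical; the NaN in `angle` is planted to show what a missing value in the extra feature would look like.

```python
import torch
import torch.nn.functional as F


def find_non_finite(**tensors: torch.Tensor) -> list[str]:
    """Return the names of the given tensors that contain NaN or inf."""
    return [name for name, t in tensors.items() if not torch.isfinite(t).all()]


# Hypothetical batch: the images are clean, but one angle value is NaN.
images = torch.randn(4, 1, 16, 16)
angle = torch.tensor([0.1, float("nan"), 0.3, 0.4])
print(find_non_finite(images=images, angle=angle))  # ['angle']

# The same check on the loss, before calling backward().
logits = torch.randn(4, 2)
labels = torch.tensor([0, 1, 1, 0])
loss = F.cross_entropy(logits, labels)
print(find_non_finite(loss=loss))  # []
```

Running the input check once over the whole dataset, rather than per batch, is usually enough to catch bad values in the extra feature before training starts.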

Upvotes: 1
