Nulle

Reputation: 1329

Pytorch: "Model Weights not Changing"

Can someone help me understand why the weights are not updating?

    import torch
    from torch.autograd import Variable  # Variable is deprecated; plain tensors work in modern PyTorch

    unet = Unet()  # the model definition is below
    optimizer = torch.optim.Adam(unet.parameters(), lr=0.001)
    loss_fn = torch.nn.MSELoss()
    input = Variable(torch.randn(32, 1, 64, 64, 64), requires_grad=True)
    target = Variable(torch.randn(32, 1, 64, 64, 64), requires_grad=False)

    optimizer.zero_grad()
    y_pred = unet(input)
    # Crop the target to match the output size of the valid (unpadded) convolutions.
    y = target[:, :, 20:44, 20:44, 20:44]

    loss = loss_fn(y_pred, y)
    print(unet.conv1.weight.data[0][0])  # weights of the first layer in the unet
    loss.backward()
    optimizer.step()
    print(unet.conv1.weight.data[0][0])  # weights haven't changed

The model is defined as follows:

    import torch.nn as nn
    import torch.nn.functional as F

    class Unet(nn.Module):

        def __init__(self):
            super(Unet, self).__init__()

            # Down hill 1
            self.conv1 = nn.Conv3d(1, 2, kernel_size=3, stride=1)
            self.conv2 = nn.Conv3d(2, 2, kernel_size=3, stride=1)

            # Down hill 2
            self.conv3 = nn.Conv3d(2, 4, kernel_size=3, stride=1)
            self.conv4 = nn.Conv3d(4, 4, kernel_size=3, stride=1)

            # Bottom
            self.convbottom1 = nn.Conv3d(4, 8, kernel_size=3, stride=1)
            self.convbottom2 = nn.Conv3d(8, 8, kernel_size=3, stride=1)

            # Up hill 1
            self.upConv0 = nn.Conv3d(8, 4, kernel_size=3, stride=1)
            self.upConv1 = nn.Conv3d(4, 4, kernel_size=3, stride=1)
            self.upConv2 = nn.Conv3d(4, 2, kernel_size=3, stride=1)

            # Up hill 2
            self.upConv3 = nn.Conv3d(2, 2, kernel_size=3, stride=1)
            self.upConv4 = nn.Conv3d(2, 1, kernel_size=1, stride=1)

            self.mp = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
            # some more irrelevant properties...

The forward function looks like:

        def forward(self, input):
            # Use U-Net theory to update the filters.
            # Example approach...
            input = F.relu(self.conv1(input))
            input = F.relu(self.conv2(input))

            input = self.mp(input)

            input = F.relu(self.conv3(input))
            input = F.relu(self.conv4(input))

            input = self.mp(input)

            input = F.relu(self.convbottom1(input))
            input = F.relu(self.convbottom2(input))

            input = F.interpolate(input, scale_factor=2, mode='trilinear')

            input = F.relu(self.upConv0(input))
            input = F.relu(self.upConv1(input))

            input = F.interpolate(input, scale_factor=2, mode='trilinear')

            input = F.relu(self.upConv2(input))
            input = F.relu(self.upConv3(input))

            input = F.relu(self.upConv4(input))

            return input

I have followed the approach of every example and piece of documentation I could find, and it is beyond me why this doesn't work.

As far as I can tell, y_pred.grad is None after the backward call, which it shouldn't be. If we have no gradient, then of course the optimizer can't change the weights in any direction, but why is there no gradient?
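For reference, here is how one can inspect the gradients on the parameters themselves (a minimal sketch; by default PyTorch only populates .grad on leaf tensors such as parameters, so they are a more reliable place to look than an intermediate result like y_pred):

    # Assumes the snippet above has just run, i.e. loss.backward() was called.
    for name, p in unet.named_parameters():
        grad_max = None if p.grad is None else p.grad.abs().max().item()
        print(name, grad_max)  # None or 0.0 means no useful gradient arrived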

Upvotes: 1

Views: 2986

Answers (2)

Nulle

Reputation: 1329

I identified this problem as an instance of the "dying ReLU problem". Because the data are Hounsfield units (which start around -1000 for air) and PyTorch draws the initial weights from a uniform distribution, many neurons start out in ReLU's zero region, leaving them paralyzed and dependent on other neurons to produce a gradient that could pull them out of the zero region. That is unlikely to happen; instead, as training progresses, all neurons get pushed into ReLU's zero region.

There are several solutions to this problem. You can use leaky ReLU or another activation function that does not have a zero region, as in the sketch below.
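A minimal sketch of a convolution block using F.leaky_relu in place of F.relu (the layer shapes here are illustrative, not the ones from the question):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class Block(nn.Module):
        def __init__(self):
            super(Block, self).__init__()
            self.conv = nn.Conv3d(1, 2, kernel_size=3, stride=1)

        def forward(self, x):
            # negative_slope=0.01 (the default) keeps a small gradient flowing
            # for negative pre-activations, so neurons cannot fully die.
            return F.leaky_relu(self.conv(x), negative_slope=0.01)

    x = torch.randn(1, 1, 16, 16, 16)
    print(Block()(x).shape)  # torch.Size([1, 2, 14, 14, 14])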

You can also normalize the input data with batch normalization and initialize the weights to be strictly positive, as sketched below.
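A minimal sketch of that idea; the uniform (0, 0.1) range for the positive initialization is an assumption for illustration, not a tuned choice:

    import torch
    import torch.nn as nn

    conv1 = nn.Conv3d(1, 2, kernel_size=3, stride=1)
    bn1 = nn.BatchNorm3d(2)  # normalizes each channel of the conv output

    with torch.no_grad():
        # One possible strictly positive initialization: uniform in (0, 0.1).
        nn.init.uniform_(conv1.weight, a=0.0, b=0.1)
        conv1.bias.zero_()

    x = torch.randn(1, 1, 16, 16, 16)
    print(bn1(conv1(x)).shape)  # torch.Size([1, 2, 14, 14, 14])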

Solution number two is probably the better option, since both will solve the problem, but leaky ReLU will prolong training, whereas batch normalization will do the opposite and increase accuracy. On the other hand, leaky ReLU is an easy fix, whereas the other solution requires a little extra work.

For Hounsfield data, one could also add a constant of 1000 to the input, eliminating negative units from the data. This still requires a different weight initialization than PyTorch's standard initialization.
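A short illustration of that shift (the HU range used here is an assumption; air sits at about -1000 HU):

    import torch

    # Hypothetical volume in Hounsfield units.
    hu = torch.randint(-1000, 3000, (1, 1, 16, 16, 16)).float()
    shifted = hu + 1000  # every voxel is now non-negative
    print(shifted.min() >= 0)  # tensor(True)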

Upvotes: 3

artona

Reputation: 1272

I do not think the weights should be printed with the command you use. Try print(unet.conv1.state_dict()["weight"]) instead of print(unet.conv1.weight.data[0][0]).

Upvotes: 0
