Penguin

Reputation: 2411

Forcing NN weights to always be in a certain range

I have a simple model:

import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 10)
        self.fc2 = nn.Linear(10, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x

net = Model()

How can I force the weights to always stay within a certain range (e.g. [-1, 1])?

I tried the following:

self.fc1 = torch.tanh(nn.Linear(3, 10))

(I wasn't even sure this would keep them in range if a gradient update tried to push them farther out.)

In any case, I got the following error:

TypeError: tanh(): argument 'input' (position 1) must be Tensor, not Linear

Upvotes: 0

Views: 2451

Answers (1)

yakhyo

Reputation: 1656

As a side note, your attempt fails because torch.tanh expects a Tensor, not an nn.Linear module; tanh can only be applied to a layer's output inside forward, not to the layer itself. To constrain the weights themselves, you can, according to this discuss.pytorch.org thread, create an extra class that clips the weights to a given range. Link to the discussion. Your model definition stays the same:

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 10)
        self.fc2 = nn.Linear(10, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x)) 
        return x

Then add a weight clipper:

class WeightClipper(object):

    def __call__(self, module):
        # filter the modules to get the ones you want
        if hasattr(module, 'weight'):
            w = module.weight.data
            w = w.clamp(-1, 1)  # clamp the weights to [-1, 1]
            module.weight.data = w


model = Model()
clipper = WeightClipper()
model.apply(clipper)
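
Note that model.apply(clipper) clamps the weights only once. During training you have to re-apply the clipper after every optimizer step, otherwise the next update can push the weights back outside [-1, 1]. A minimal training-loop sketch, assuming a regression setup; the optimizer, loss, learning rate, and dummy data are made-up placeholders:

import torch
import torch.nn as nn

model = Model()
clipper = WeightClipper()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()  # placeholder loss, pick whatever fits your task

# dummy batch just to make the sketch runnable; shapes match the model (3 in, 2 out)
x = torch.randn(16, 3)
y = torch.randn(16, 2)

for epoch in range(100):
    optimizer.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()
    optimizer.step()
    model.apply(clipper)  # re-clamp all weights to [-1, 1] after each update

This keeps every weight tensor in [-1, 1] at the end of each step. The biases are left untouched, since the clipper only looks for a weight attribute; extend the check in __call__ if you want to clamp those too.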

Upvotes: 2
