Reputation: 2411
I have a simple model:
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 10)
        self.fc2 = nn.Linear(10, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x

net = Model()
How can I constrain the weights so that they always stay within a given range (e.g. between -1 and 1)?
I tried the following:
self.fc1 = torch.tanh(nn.Linear(3, 10))
I'm not entirely sure this would always keep the weights in that range (even if a gradient update tries to push them further out).
In any case, it fails with the following error:
TypeError: tanh(): argument 'input' (position 1) must be Tensor, not Linear
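For context, the error arises because `torch.tanh` expects a tensor, while `nn.Linear(3, 10)` is a module. A minimal sketch of what that line seems to be aiming for, assuming a PyTorch version that provides `torch.nn.utils.parametrize` (1.9 or later): the stored parameter is left unconstrained and the layer uses its tanh as the weight. `TanhWeight` is a hypothetical name, and the scaling by 100 only simulates an update that pushes the raw values far out.

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize


class TanhWeight(nn.Module):
    """Hypothetical parametrization: maps the raw parameter into [-1, 1]."""

    def forward(self, raw):
        return torch.tanh(raw)


fc1 = nn.Linear(3, 10)
# After this call, fc1.weight is computed as tanh(raw) on every access;
# the raw tensor lives in fc1.parametrizations.weight.original.
parametrize.register_parametrization(fc1, "weight", TanhWeight())

# Simulate an update that pushes the raw values far from zero.
with torch.no_grad():
    fc1.parametrizations.weight.original.mul_(100.0)

# The effective weight is still bounded by the tanh.
max_abs_weight = fc1.weight.abs().max().item()
out = fc1(torch.randn(4, 3))
```

With this approach the optimizer updates the raw tensor, so no clipping step is needed, at the cost of vanishing gradients when the raw values saturate the tanh.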
Upvotes: 0
Views: 2451
Reputation: 1656
According to a thread on discuss.pytorch, you can create an extra class that clips the weights to a given range (link to the discussion).
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(3, 10)
        self.fc2 = nn.Linear(10, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.tanh(self.fc3(x))
        return x
Then add a weight clipper and apply it to the model:
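class WeightClipper(object):
    def __call__(self, module):
        # Only touch modules that own a weight tensor (here, the nn.Linear layers)
        if hasattr(module, 'weight'):
            w = module.weight.data
            w = w.clamp(-1, 1)
            module.weight.data = w

model = Model()
clipper = WeightClipper()
# apply() visits every submodule; a single call clips the weights only once,
# so repeat it after each optimizer.step() to keep them in range during training
model.apply(clipper)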
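Applying the clipper once only clamps the current values; later gradient updates can move the weights out of range again. A minimal sketch of how the clipping could be repeated during training, assuming plain SGD, an MSE loss, and random dummy data (the optimizer, loss, data, and the `clip_weights` helper are illustrative choices, not part of the original answer):

```python
import torch
import torch.nn as nn


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(3, 10)
        self.fc2 = nn.Linear(10, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.tanh(self.fc3(x))


def clip_weights(module):
    # Hypothetical helper: clamp Linear weights in place (biases are left alone).
    if isinstance(module, nn.Linear):
        with torch.no_grad():
            module.weight.clamp_(-1.0, 1.0)


torch.manual_seed(0)
model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)
loss_fn = nn.MSELoss()

# Dummy data, for illustration only.
inputs = torch.randn(16, 3)
targets = torch.randn(16, 2)

for _ in range(5):
    optimizer.zero_grad()
    loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
    # Re-clip after every update so the weights never leave [-1, 1].
    model.apply(clip_weights)
```

The in-place `clamp_` under `torch.no_grad()` does the same job as reassigning `module.weight.data`, without going through `.data`.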
Upvotes: 2