Reputation: 3472
I have a neural network with the following structure:
class myNetwork(nn.Module):
    """Bidirectional GRU followed by a two-layer fully connected head.

    Sub-modules (created in this order, so the RNG is consumed in this order):
      * ``bigru`` -- ``nn.GRU`` with 2 input features, 100 hidden units,
        ``batch_first=True`` and ``bidirectional=True`` (so 2 * 100 = 200
        output features, which is what ``fc1`` expects).
      * ``fc1``   -- ``nn.Linear(200, 32)``; weight re-drawn with Xavier-uniform.
      * ``fc2``   -- ``nn.Linear(32, 2)``; weight re-drawn with Xavier-uniform.

    Biases keep the default ``nn.Linear`` initialisation.
    """

    def __init__(self):
        super().__init__()
        self.bigru = nn.GRU(
            input_size=2,
            hidden_size=100,
            batch_first=True,
            bidirectional=True,
        )
        # Each Linear is built and immediately re-initialised, before the next
        # one is created, so the random draws happen in the original order.
        self.fc1 = nn.Linear(200, 32)
        torch.nn.init.xavier_uniform_(self.fc1.weight)
        self.fc2 = nn.Linear(32, 2)
        torch.nn.init.xavier_uniform_(self.fc2.weight)
I need to restore the model to an untrained state by resetting the parameters of the neural network. I can do so for the `nn.Linear`
layers by using the method below:
def reset_weights(self):
    """Put ``fc1`` and ``fc2`` back into the state ``__init__`` creates them in.

    ``__init__`` builds each layer with ``nn.Linear`` (default initialisation
    of both weight *and* bias) and then overwrites only the weight with
    Xavier-uniform. Re-applying Xavier to the weight alone would leave the
    trained biases in place, so each layer is first restored with its own
    ``reset_parameters()`` (weight and bias) and then given a fresh
    Xavier-uniform weight, exactly mirroring ``__init__``.

    The GRU (``self.bigru``) is not touched here.
    """
    for layer in (self.fc1, self.fc2):
        layer.reset_parameters()  # default nn.Linear init for weight and bias
        torch.nn.init.xavier_uniform_(layer.weight)
However, I could not find any such snippet for resetting the weights of the `nn.GRU`
layer.
My question is: how does one reset the `nn.GRU`
layer? Any other way of resetting the network is also fine. Any help is appreciated.
Upvotes: 7
Views: 31320
Reputation: 5251
Here is the code with an example that runs:
def lp_norm(mdl: nn.Module, p: float = 2) -> torch.Tensor:
    """Return the sum of the per-parameter Lp norms of ``mdl``.

    For every parameter tensor (recursively, via ``mdl.parameters()``), the Lp
    norm over all of its elements is computed with ``Tensor.norm(p)``. These
    per-parameter norms are then added together, so the result is *not* the Lp
    norm of all parameters concatenated into a single vector.

    Args:
        mdl: module whose parameters are measured.
        p: order of the norm (default 2).

    Returns:
        A 0-dim tensor. For a module without parameters this is
        ``torch.tensor(0.0)`` rather than the plain int ``0`` that ``sum([])``
        would return.
    """
    lp_norms = [w.norm(p) for w in mdl.parameters()]
    if not lp_norms:
        return torch.tensor(0.0)
    return sum(lp_norms)
def reset_all_weights(model: nn.Module) -> None:
    """Re-initialise, in place, every submodule of ``model`` that can reset itself.

    ``model.apply`` visits ``model`` and all of its submodules recursively; each
    visited module that has a callable ``reset_parameters`` attribute has it
    called (under ``torch.no_grad()``). Modules without one are left untouched.

    refs:
        - https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/6
        - https://stackoverflow.com/questions/63627997/reset-parameters-of-a-neural-network-in-pytorch
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def _reset_if_supported(module: nn.Module):
        # Skip modules that do not know how to reset themselves.
        resetter = getattr(module, "reset_parameters", None)
        if not callable(resetter):
            return
        module.reset_parameters()

    model.apply(fn=_reset_if_supported)
def reset_all_linear_layer_weights(model: nn.Module, value: float = 1.0) -> nn.Module:
    """Fill the weight of every ``nn.Linear`` in ``model`` with a constant.

    Despite the name, this is not a random re-initialisation. Every Linear
    weight matrix (including subclasses of ``nn.Linear``, found recursively via
    ``Module.apply``) is set in place to ``value``, which defaults to 1.0 as
    before. Biases are left unchanged.

    Args:
        model: module tree to modify in place.
        value: constant written into every Linear weight.

    Returns:
        ``model`` itself, matching the annotated return type.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m: nn.Module) -> None:
        if isinstance(m, nn.Linear):
            m.weight.fill_(value)

    # Module.apply returns the module it was called on.
    return model.apply(init_weights)
def reset_all_weights_with_specific_layer_type(model: nn.Module, modules_type2reset) -> nn.Module:
    """Call ``reset_parameters()`` on every submodule of the given type(s).

    ``Module.apply`` walks ``model`` and all of its submodules recursively.
    Each module that is an instance of ``modules_type2reset`` (subclasses
    included) is re-initialised in place, under ``torch.no_grad()``.

    Args:
        model: module tree to modify in place.
        modules_type2reset: a module class, or a tuple of classes, in any form
            accepted by ``isinstance`` (e.g. ``nn.BatchNorm2d`` or
            ``(nn.Linear, nn.Conv2d)``). Matching modules must define
            ``reset_parameters``.

    Returns:
        ``model`` itself, matching the annotated return type.

    ref:
        - https://pytorch.org/docs/stable/generated/torch.nn.Module.html
    """

    @torch.no_grad()
    def init_weights(m: nn.Module) -> None:
        if isinstance(m, modules_type2reset):
            m.reset_parameters()

    # Module.apply returns the module it was called on.
    return model.apply(init_weights)
# -- tests
def reset_params_test():
    """Smoke-test ``reset_all_weights`` on a pre-trained torchvision ResNet-18.

    Prints the summed L2 norm (``lp_norm``, defined in this file) of a
    pre-trained ResNet-18, of a freshly constructed one, and of the pre-trained
    one again before and after ``reset_all_weights``. It then asserts that the
    reset actually changed the parameters.

    Needs ``torchvision`` and network access for the pre-trained weights.
    """
    import torchvision.models as models

    # NOTE: `pretrained=` is deprecated since torchvision 0.13 in favour of
    # `weights=`; it is kept so the example also runs on older versions.
    resnet18 = models.resnet18(pretrained=True)
    resnet18_random = models.resnet18(pretrained=False)
    print(f'{lp_norm(resnet18)=}')
    print(f'{lp_norm(resnet18_random)=}')
    norm_before_reset = lp_norm(resnet18)
    print(f'{lp_norm(resnet18)=}')
    reset_all_weights(resnet18)
    print(f'{lp_norm(resnet18)=}')
    assert not torch.equal(norm_before_reset, lp_norm(resnet18)), 'reset_all_weights did not change the parameters'
if __name__ == '__main__':
    # Run the smoke test only when this file is executed as a script.
    reset_params_test()
    print('Done! \a\n')
output:
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18_random)=tensor(668.3687, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(517.5472, grad_fn=<AddBackward0>)
lp_norm(resnet18)=tensor(476.0836, grad_fn=<AddBackward0>)
Done!
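I assume this works because I computed the norm of the pre-trained net twice before calling reset and got the same value both times, and the norm changed after the reset.
I was hoping it would end up closer to the norm of the freshly constructed random net, but I think this is good enough. The gap is expected. torchvision's ResNet constructor re-initialises its conv layers with `kaiming_normal_` (mode `fan_out`). In contrast, `nn.Conv2d.reset_parameters()` uses PyTorch's default `kaiming_uniform_` scheme, which typically draws smaller weights.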
The same answer is also posted here: https://discuss.pytorch.org/t/how-to-re-set-alll-parameters-in-a-network/20819/11
Upvotes: 4
Reputation: 1
I'm new to PyTorch, so I wonder if this could be a solution :)
Suppose `Model` is an instance of a class that inherits from `torch.nn.Module`.
To reset it to zeros:
# Zero every entry of Model's state dict (parameters and buffers).
# state_dict() returns tensors that share storage with the module's own
# tensors, so the in-place zero_() already updates Model; load_state_dict
# below just makes the intent explicit.
# zero_() is used instead of `*= 0` because NaN * 0 and inf * 0 are NaN, so a
# model whose weights diverged would not actually be cleared.
# NOTE: an all-zero network gives every unit in a layer identical updates
# (no symmetry breaking), so it is rarely a useful starting point for training.
dic = Model.state_dict()
for k in dic:
    dic[k].zero_()
Model.load_state_dict(dic)
del dic
To reset it randomly:
# Re-draw the learnable entries of Model's state dict from N(0, 1).
# Buffers are skipped:
#  - BatchNorm's running_var drawn from N(0, 1) would contain negative
#    variances (NaN outputs in eval mode);
#  - integer buffers such as num_batches_tracked would get truncated random
#    values.
# Skipped buffers keep their current (possibly trained) values.
# randn_like keeps each tensor's dtype and device.
# NOTE: N(0, 1) is a much larger scale than PyTorch's default layer inits;
# calling each module's reset_parameters() is usually the better reset.
buffer_names = {name for name, _ in Model.named_buffers()}
dic = Model.state_dict()
for k in dic:
    if k not in buffer_names:
        dic[k] = torch.randn_like(dic[k])
Model.load_state_dict(dic)
del dic
Upvotes: 0
Reputation: 7693
You can call the `reset_parameters`
method on each layer, as shown here:
# model.modules() walks the whole module tree (model itself included), so
# layers nested inside containers such as nn.Sequential are reset too.
# model.children() would only visit the direct sub-modules.
for layer in model.modules():
    if callable(getattr(layer, 'reset_parameters', None)):
        layer.reset_parameters()
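Another way is to save the model's state right after constructing it (before any training) and reload that state whenever you want to reset. Use `torch.save`
and `torch.load` together with `state_dict()` / `load_state_dict()`.
See the docs for more, or the "Saving and Loading Models" tutorial.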
Upvotes: 13