Let's say I wanted to multiply all parameters of a neural network in PyTorch (an instance of a class inheriting from torch.nn.Module) by 0.9. How would I do that?
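Let net be an instance of a neural network nn.Module. Then, to multiply all parameters by 0.9 (note that state_dict() also returns the module's buffers, such as BatchNorm running statistics, so those are scaled too):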
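state_dict = net.state_dict()
for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9
    # Update the parameter in place. The tensors returned by state_dict()
    # share storage with the module, so this modifies net itself.
    param.copy_(transformed_param)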
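The loop above works because state_dict() returns tensors that share storage with the module, so the in-place copy_ writes through to the model. A minimal self-contained check is sketched below, assuming a small hypothetical network built from nn.Linear and nn.BatchNorm1d; it also shows that float buffers such as BatchNorm's running_var are scaled along with the parameters.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy network: a Linear layer (parameters only) and a
# BatchNorm1d layer (parameters plus buffers).
net = nn.Sequential(nn.Linear(3, 2), nn.BatchNorm1d(2))

# Independent copies of the original values, for comparison afterwards.
before = {name: t.clone() for name, t in net.state_dict().items()}

state_dict = net.state_dict()
for name, param in state_dict.items():
    # In-place copy writes into the module's own storage.
    param.copy_(param * 0.9)

after = net.state_dict()
for name, old in before.items():
    # Skip integer buffers such as num_batches_tracked.
    if old.is_floating_point():
        # Weights, biases and float buffers have all been scaled.
        assert torch.allclose(after[name], old * 0.9)

# BatchNorm's running_var starts as ones, so it now holds 0.9.
print(after["1.running_var"])
```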
If you only want to update the weights rather than every parameter:
state_dict = net.state_dict()
for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if "weight" not in name:
        continue
    # Transform the parameter as required.
    transformed_param = param * 0.9
    # Update the parameter in place.
    param.copy_(transformed_param)
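If buffers should stay untouched, a similar approach can iterate over named_parameters() instead of state_dict(). The sketch below assumes a hypothetical helper, scale_parameters, with an only_weights flag introduced purely for illustration; the update runs under torch.no_grad() so autograd does not record the in-place multiplication.

```python
import torch
import torch.nn as nn


def scale_parameters(module: nn.Module, factor: float, only_weights: bool = False) -> None:
    """Hypothetical helper: multiply a module's parameters in place by `factor`."""
    with torch.no_grad():
        for name, param in module.named_parameters():
            # Same substring test as the state_dict loop above.
            if only_weights and "weight" not in name:
                continue
            param.mul_(factor)


torch.manual_seed(0)
# Hypothetical toy network for the demonstration.
net = nn.Sequential(nn.Linear(3, 2), nn.BatchNorm1d(2))
w0 = net[0].weight.detach().clone()
b0 = net[0].bias.detach().clone()

scale_parameters(net, 0.9, only_weights=True)

assert torch.allclose(net[0].weight, w0 * 0.9)         # weights scaled
assert torch.equal(net[0].bias, b0)                    # biases untouched
assert torch.equal(net[1].running_var, torch.ones(2))  # buffers untouched
```

Because the filter is a substring match, it also catches names such as 1.weight here, which is BatchNorm's learnable scale rather than a layer weight in the usual sense.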
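A different way of achieving this is to iterate over Module.parameters() (the example below multiplies by 0 so the effect is easy to see; use 0.9 for the case in the question).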
Initialize a module:
>>> a = torch.nn.Linear(2, 2)
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0.1770, -0.2151],
                      [-0.6543, 0.6637]])),
             ('bias', tensor([-0.0524, 0.6807]))])
Change the parameters:
for p in a.parameters():
    p.data *= 0
See the effect:
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0., -0.],
                      [-0., 0.]])),
             ('bias', tensor([-0., 0.]))])
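The same loop can apply the factor from the question. A minimal sketch, assuming a fresh nn.Linear(2, 2) named a as above, is shown below; it uses torch.no_grad() with an in-place mul_ as an alternative to writing through .data, so the update is done on the parameters themselves without being recorded by autograd.

```python
import torch

torch.manual_seed(0)
a = torch.nn.Linear(2, 2)

# Snapshot the original values so the effect can be checked.
original = [p.detach().clone() for p in a.parameters()]

# Scale every parameter in place, outside of autograd tracking.
with torch.no_grad():
    for p in a.parameters():
        p.mul_(0.9)

# Each parameter is now 0.9 times its original value.
for p, p0 in zip(a.parameters(), original):
    assert torch.allclose(p, p0 * 0.9)

print(a.state_dict())
```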