the-bass

Reputation: 745

How can I update the parameters of a neural network in PyTorch?

Let's say I wanted to multiply all parameters of a neural network in PyTorch (an instance of a class inheriting from torch.nn.Module) by 0.9. How would I do that?

Upvotes: 2

Views: 15341

Answers (2)

the-bass

Reputation: 745

Let net be an instance of a neural network (a subclass of nn.Module). Then, to multiply all of its parameters by 0.9:

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    param.copy_(transformed_param)
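To check that this really mutates the model, here is a minimal runnable sketch; the one-layer net and the before/after comparison are my own additions for illustration, not part of the original answer.

```python
import torch

# Hypothetical tiny network for demonstration.
net = torch.nn.Linear(2, 2)
weight_before = net.weight.detach().clone()
bias_before = net.bias.detach().clone()

# With the default keep_vars=False, state_dict() returns detached
# tensors that still share storage with the parameters, so the
# in-place copy_ below updates the model itself.
state_dict = net.state_dict()
for name, param in state_dict.items():
    param.copy_(param * 0.9)

print(torch.allclose(net.weight.detach(), weight_before * 0.9))  # True
print(torch.allclose(net.bias.detach(), bias_before * 0.9))      # True
```

The storage sharing is why `copy_` works here: the tensors in the returned dict are detached views of the live parameters, not copies.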

If you want to only update weights instead of every parameter:

state_dict = net.state_dict()

for name, param in state_dict.items():
    # Don't update if this is not a weight.
    if "weight" not in name:
        continue
    
    # Transform the parameter as required.
    transformed_param = param * 0.9

    # Update the parameter.
    param.copy_(transformed_param)
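As a sanity check of the weight-only filter, here is a runnable sketch on a hypothetical two-layer net (names like "0.weight" and "0.bias" come from nn.Sequential); the net and the assertions are my own additions.

```python
import torch

# Hypothetical two-layer net so the state dict has several entries:
# "0.weight", "0.bias", "1.weight", "1.bias".
net = torch.nn.Sequential(torch.nn.Linear(2, 4), torch.nn.Linear(4, 1))
weight_before = net[0].weight.detach().clone()
bias_before = net[0].bias.detach().clone()

state_dict = net.state_dict()
for name, param in state_dict.items():
    # Skip everything that is not a weight (e.g. "0.bias").
    if "weight" not in name:
        continue
    param.copy_(param * 0.9)

print(torch.allclose(net[0].weight.detach(), weight_before * 0.9))  # True
print(torch.equal(net[0].bias.detach(), bias_before))               # True
```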

Upvotes: 12

Onno Eberhard

Reputation: 1541

A different way of achieving this is to use Module.parameters().

Initialize module:

>>> a = torch.nn.Linear(2, 2)
>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0.1770, -0.2151],
                      [-0.6543,  0.6637]])),
             ('bias', tensor([-0.0524,  0.6807]))])

Change the parameters (here, zeroing them out by multiplying by 0):

for p in a.parameters():
    p.data *= 0

See the effect:

>>> a.state_dict()
OrderedDict([('weight',
              tensor([[-0., -0.],
                      [-0., 0.]])),
             ('bias', tensor([-0., 0.]))])
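Mutating `.data` bypasses autograd's tracking; a sketch of the same update using the idiom current PyTorch documentation recommends instead, in-place operations inside a torch.no_grad() block (the one-layer net and the check are my own additions):

```python
import torch

a = torch.nn.Linear(2, 2)
before = a.weight.detach().clone()

# In-place multiply on each parameter; no_grad() tells autograd
# not to record these operations.
with torch.no_grad():
    for p in a.parameters():
        p.mul_(0.9)

print(torch.allclose(a.weight.detach(), before * 0.9))  # True
```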

Upvotes: 1