Daniel

Reputation: 12026

Overwriting vs. mutating PyTorch weights

I'm trying to understand why I cannot directly overwrite the weights of a torch layer. Consider the following example:

import torch
from torch import nn

net = nn.Linear(3, 1)
weights = torch.zeros(1,3)

# Overwriting does not work
net.state_dict()["weight"] = weights  # nothing happens
print(f"{net.state_dict()['weight']=}")

# But mutating does work
net.state_dict()["weight"][0] = weights  # indexing works
print(f"{net.state_dict()['weight']=}")

# Output:
# net.state_dict()['weight']=tensor([[ 0.5464, -0.4110, -0.1063]])
# net.state_dict()['weight']=tensor([[0., 0., 0.]])

I'm confused since state_dict()["weight"] is just a torch tensor, so I feel I'm missing something really obvious here.

Upvotes: 2

Views: 503

Answers (2)

ihdv

Reputation: 2287

This is because every call to net.state_dict() creates a new collections.OrderedDict object, fills it with (detached) references to this module's parameters and buffers, here weight and bias, and returns that dict:

state_dict = net.state_dict()
print(type(state_dict))    # <class 'collections.OrderedDict'>

When you "overwrite" an entry of this ordered dict (it is in fact not an overwrite, just ordinary Python assignment), you bind the key 'weight' of this temporary dict to your new zeros tensor. The data in the original weight tensor is not modified; the dict simply no longer refers to it, and the dict itself is discarded as soon as the statement finishes.
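A minimal sketch of that rebinding, assuming only the nn.Linear(3, 1) setup from the question (the names `sd` and `before` are illustrative): if you keep a reference to one state dict, you can see that its "weight" entry now points at the zeros tensor while net.weight keeps its old values.

```python
import torch
from torch import nn

net = nn.Linear(3, 1)
weights = torch.zeros(1, 3)
before = net.weight.detach().clone()  # snapshot of the initial parameter values

sd = net.state_dict()    # keep a reference to this one dict
sd["weight"] = weights   # rebinds the dict key; no tensor data is written

print(sd["weight"] is weights)          # True: the dict now holds our tensor
print(torch.equal(net.weight, before))  # True: the parameter is unchanged
```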

When you check whether the tensor is modified by:

print(f"{net.state_dict()['weight']}")

a new ordered dict, different from the one you modified, is created, so you see the unchanged tensor.
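A short check along these lines should confirm both halves of that explanation (a sketch; `data_ptr()` returns the address of the tensor's first element, which matches here because state_dict() stores detached tensors that share storage with the parameters by default):

```python
import torch
from torch import nn

net = nn.Linear(3, 1)

first = net.state_dict()
second = net.state_dict()

# each call builds a brand-new OrderedDict ...
print(first is second)  # False
# ... but the tensor inside shares memory with the module's parameter
print(first["weight"].data_ptr() == net.weight.data_ptr())  # True
```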

However, when you use indexing like this:

net.state_dict()["weight"][0] = weights  # indexing works

then it is no longer an assignment to the ordered dict. Instead, the tensor's __setitem__ method is called, which writes into the tensor's memory in place. Because the tensor stored in the state dict is a detached tensor that shares storage with net.weight, the parameter changes as well. Other in-place tensor APIs such as copy_ achieve the same result.
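A sketch of three ways to actually change the weights, assuming the same nn.Linear(3, 1); the snapshot variable names are illustrative. The first option is the copy_ route mentioned above; copying into the parameter under torch.no_grad() and editing a dict then calling load_state_dict are the common alternatives.

```python
import torch
from torch import nn

net = nn.Linear(3, 1)

# 1) in-place copy into the detached tensor returned by state_dict();
#    it shares memory with net.weight, so the parameter changes too
net.state_dict()["weight"].copy_(torch.zeros(1, 3))
after_state_dict_copy = net.weight.detach().clone()

# 2) in-place copy into the parameter itself; no_grad() is needed because
#    net.weight is a leaf tensor that requires grad
with torch.no_grad():
    net.weight.copy_(torch.full((1, 3), 2.0))
after_param_copy = net.weight.detach().clone()

# 3) edit a state dict, then load it back; load_state_dict copies the
#    values into the module's existing parameters
sd = net.state_dict()
sd["weight"] = torch.ones(1, 3)
net.load_state_dict(sd)

print(after_state_dict_copy)  # zeros
print(after_param_copy)       # twos
print(net.weight)             # ones
```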

A clear explanation of the difference between a = b and a[:] = b when a is a tensor/array can be found here: https://stackoverflow.com/a/68978622/11790637
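The same distinction in isolation, as a small sketch with plain tensors (all names are illustrative): rebinding a name leaves other references untouched, while slice assignment writes into the memory both names share.

```python
import torch

a = torch.arange(3.0)
alias = a
a = torch.zeros(3)      # rebinds the name `a` only
print(alias)            # tensor([0., 1., 2.])

b = torch.arange(3.0)
alias_b = b
b[:] = torch.zeros(3)   # writes into the tensor both names refer to
print(alias_b)          # tensor([0., 0., 0.])
```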

Upvotes: 1

John Stud

Reputation: 1779

I don't have torch installed right now, but try something like this from some saved code I have. I believe you need to make deep copies, like so:

import copy

import torch


def zero_injection(initial_weights, trained_weights, mask):
    ''' zeros all weights and then injects in masked selection '''
    # deep-copy the state dicts so the source models are left untouched
    initial_weights_copy = copy.deepcopy(initial_weights.state_dict())
    trained_weights_copy = copy.deepcopy(trained_weights.state_dict())

    # set all the values to zero (in place, on the copies)
    for value in initial_weights_copy.values():
        value.zero_()

    state_dict = {}
    # for each key: where mask[key] is False, replace the (zeroed) initial
    # value with the trained value; mask[key] must be a bool tensor.
    # Tensors stay on their current device; move the models and masks
    # with .cuda() beforehand if the result should live on the GPU.
    for key, value in initial_weights_copy.items():
        state_dict[key] = torch.where(mask[key], value, trained_weights_copy[key])

    return state_dict
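Note that the function above only builds a new dict; no model changes until that dict is loaded. A compact, self-contained sketch of the same masking idea followed by load_state_dict, assuming two nn.Linear(3, 1) models and a hypothetical boolean `mask` dict (True keeps the zeroed initial value, False takes the trained value):

```python
import torch
from torch import nn

initial = nn.Linear(3, 1)
trained = nn.Linear(3, 1)

# hypothetical masks, one bool tensor per state-dict key:
# True -> keep the zeroed initial value, False -> take the trained value
mask = {
    "weight": torch.tensor([[True, False, True]]),
    "bias": torch.tensor([False]),
}

trained_sd = trained.state_dict()
merged = {
    key: torch.where(mask[key], torch.zeros_like(value), trained_sd[key])
    for key, value in initial.state_dict().items()
}

# the merged dict has no effect until it is loaded into a module
initial.load_state_dict(merged)
print(initial.weight)
```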

Upvotes: 0
