Nikoo_Ebrahimi

Reputation: 63

Prevent Updating for Specific Element of Convolutional Weight Matrix

I’m trying to set one element of a convolutional weight tensor to 1 and then keep it fixed until the end of training (i.e., prevent it from being updated in subsequent epochs). I know I can set requires_grad = False, but that applies to the whole tensor; I only want this for one element, not all of them.
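Roughly what I mean (a minimal sketch; the layer shape and index are just examples):

import torch.nn as nn

conv = nn.Conv2d(3, 1, 3)
conv.weight.data[0, 0, 0, 0] = 1   # the element I want to keep fixed at 1

# requires_grad is a property of the whole tensor, so this would freeze every element:
conv.weight.requires_grad = False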

Upvotes: 2

Views: 597

Answers (2)

Ivan

Reputation: 40708

You can attach a backward hook to your nn.Module so that, during backpropagation, the gradient of the element of interest is overwritten with 0. This ensures its value never changes, without blocking backpropagation of the gradient to the input.

The new API for backward hooks is nn.Module.register_full_backward_hook. First construct a callback function that will be used as the layer hook:

def freeze_single(index):
    # Backward hook: zero out the gradient of the given weight element.
    def callback(module, grad_input, grad_output):
        module.weight.grad.data[index] = 0
    return callback

Then, we can attach this hook to any nn.Module. For instance, here I've decided to freeze component [0, 1, 2, 1] of the convolutional layer:

>>> conv = nn.Conv2d(3, 1, 3)
>>> conv.weight.data[0, 1, 2, 1] = 1

>>> conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))

Everything is set up correctly; let us try it:

>>> x = torch.rand(1, 3, 10, 10, requires_grad=True)
>>> conv(x).mean().backward()

Here we can verify the gradient of component [0, 1, 2, 1] is indeed equal to 0:

>>> conv.weight.grad
tensor([[[[0.4954, 0.4776, 0.4639],
          [0.5179, 0.4992, 0.4856],
          [0.5271, 0.5219, 0.5124]],

         [[0.5367, 0.5035, 0.5009],
          [0.5703, 0.5390, 0.5207],
          [0.5422, 0.0000, 0.5109]], # <-

         [[0.4937, 0.5150, 0.5200],
          [0.4817, 0.5070, 0.5241],
          [0.5039, 0.5295, 0.5445]]]])

You can detach/reattach the hook anytime with:

>>> hook = conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
>>> hook.remove()

Don't forget that if you remove the hook, the value of that component will change again as soon as you update your weights; you will have to reset it to 1 if you so desire. Otherwise, you can implement a second hook (a register_forward_pre_hook this time) to handle that, as sketched below.
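A minimal sketch of that second hook, continuing the example above (the pin_single helper name is my own, mirroring freeze_single):

def pin_single(index, value=1.0):
    # Pre-forward hook: reset one weight element right before every forward pass.
    def callback(module, inputs):
        with torch.no_grad():  # in-place edits of a parameter that requires grad must bypass autograd
            module.weight[index] = value
    return callback

>>> conv.register_forward_pre_hook(pin_single((0, 1, 2, 1)))

Registered alongside the backward hook, it rewrites component [0, 1, 2, 1] to 1 right before each forward pass, so the layer always computes with the pinned value.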

Upvotes: 3

Umang Gupta

Reputation: 16460

Ivan already talked about using a backward hook to override the gradient of the desired element.

An alternative approach, without hooks, is to simply overwrite the desired element of the parameter with the value before each forward pass.

Let's say you want to override the [0, 0] element of a linear layer's weight with 1; you would do:

def forward(x):
    with torch.no_grad():        # in-place edit of a parameter that requires grad must bypass autograd
        model.weight[0, 0] = 1
    # usual forward pass

When you do the backward pass and the optimizer step, that element will still receive the usual gradient update, but on the next forward pass it is overridden with 1 again, so it effectively stays at that value throughout training; the sketch below demonstrates this. The same effect can also be achieved with a hook registered before forward (register_forward_pre_hook) instead of editing forward itself.
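A minimal, self-contained sketch of this approach (the PinnedLinear class and the shapes are made up for illustration):

import torch
import torch.nn as nn

class PinnedLinear(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        with torch.no_grad():            # bypass autograd for the in-place edit
            self.fc.weight[0, 0] = 1.0   # re-pin the element before every use
        return self.fc(x)

model = PinnedLinear()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.rand(8, 4)).sum().backward()
opt.step()                           # weight[0, 0] drifts away from 1 here...
_ = model(torch.rand(8, 4))          # ...and is pinned back to 1 on the next forward pass
print(model.fc.weight[0, 0].item())  # 1.0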

Upvotes: 1
