Reputation: 63
I’m trying to set one element of a weight tensor to 1 and then keep it fixed until the end of training (i.e. prevent it from being updated in later epochs). I know I can set requires_grad = False, but I only want to do this for a single element, not for the whole tensor.
Upvotes: 2
Views: 597
Reputation: 40708
You can attach a backward hook to your nn.Module so that during backpropagation the gradient of the element of interest is overwritten with 0. This makes sure its value never changes, without preventing the gradient from propagating back to the input.
The current API for backward hooks is nn.Module.register_full_backward_hook. First construct a callback function that will be used as the layer hook:
def freeze_single(index):
    def callback(module, grad_input, grad_output):
        # zero out the gradient of the single frozen element
        # so the optimizer never updates it
        module.weight.grad.data[index] = 0
    return callback
Then, we can attach this hook to any nn.Module. For instance, here I've decided to freeze component [0, 1, 2, 1] of the convolutional layer:
>>> conv = nn.Conv2d(3, 1, 3)
>>> conv.weight.data[0, 1, 2, 1] = 1
>>> conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
Everything is now set up; let's try it:
>>> x = torch.rand(1, 3, 10, 10, requires_grad=True)
>>> conv(x).mean().backward()
Here we can verify that the gradient of component [0, 1, 2, 1] is indeed equal to 0:
>>> conv.weight.grad
tensor([[[[0.4954, 0.4776, 0.4639],
[0.5179, 0.4992, 0.4856],
[0.5271, 0.5219, 0.5124]],
[[0.5367, 0.5035, 0.5009],
[0.5703, 0.5390, 0.5207],
[0.5422, 0.0000, 0.5109]], # <-
[[0.4937, 0.5150, 0.5200],
[0.4817, 0.5070, 0.5241],
[0.5039, 0.5295, 0.5445]]]])
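Note that only the weight gradient is zeroed; the gradient with respect to the input still flows back as usual. As a quick sanity check (a sketch, reusing the x from above), the input gradient is still populated:
>>> x.grad.shape
torch.Size([1, 3, 10, 10])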
You can detach/reattach the hook anytime with:
>>> hook = conv.register_full_backward_hook(freeze_single((0, 1, 2, 1)))
>>> hook.remove()
Don't forget that if you remove the hook, the value of that component will change as soon as you update your weights; you will have to reset it to 1 if you so desire. Otherwise, you can implement a second hook, this time a register_forward_pre_hook hook, to handle that.
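A minimal sketch of what that second hook could look like (the reset_single helper is illustrative, mirroring freeze_single above, and reuses conv from the earlier snippet):
def reset_single(index, value=1.0):
    def callback(module, inputs):
        with torch.no_grad():
            # force the frozen element back to its pinned value
            # before every forward pass
            module.weight[index] = value
    return callback

>>> hook = conv.register_forward_pre_hook(reset_single((0, 1, 2, 1)))
With both hooks attached, the element's gradient is zeroed during backward and its value is reset to 1 before every forward pass.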
Upvotes: 3
Reputation: 16460
Ivan already covered using a backward hook to zero out the gradient of the desired element.
An alternative approach, without hooks, is to simply overwrite the desired parameter element with the value before every forward pass.
Say you want to pin the [0, 0] element of a linear layer to 1; you would do:
def forward(x):
    model.weight.data[0, 0] = 1  # override the element; .data avoids an in-place autograd error
    # usual forward pass
    return model(x)
During the backward pass and optimizer step the element still receives a normal gradient update, but on the next forward pass it is overwritten with 1 again, so it effectively stays at that value throughout training. This can also be achieved with a hook before forward (register_forward_pre_hook).
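For reference, here is a small self-contained sketch of this pattern (the layer size, loss, and SGD optimizer are just placeholders):
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

def forward(x):
    model.weight.data[0, 0] = 1  # reset the pinned element on every call
    return model(x)

forward(torch.randn(8, 4)).sum().backward()
opt.step()                      # weight[0, 0] drifts away from 1 here...
forward(torch.randn(8, 4))      # ...and is forced back to 1 here
print(model.weight[0, 0].item())  # 1.0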
Upvotes: 1