gasoon

Reputation: 865

How can I only update some specific tensors in network with pytorch?

For instance, I want to update only the CNN (convolutional) weights in a ResNet during the first 10 epochs and freeze all the other parameters.
From the 11th epoch on, I want to update the whole model.
How can I achieve the goal?

Upvotes: 3

Views: 2336

Answers (2)

weiyixie

Reputation: 581

This is very straightforward, because PyTorch recreates the computational graph on the fly, so you can change which parameters require gradients between iterations.

# Suppose `resnet` is one of your modules: this freezes it from training
for p in resnet.parameters():
    p.requires_grad = False

If you have multiple modules, loop over each of them in the same way. Then, after 10 epochs, call:

# Suppose your whole network is the `network` module: this unfreezes everything
for p in network.parameters():
    p.requires_grad = True
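A minimal sketch of the question's schedule with this `requires_grad` approach, assuming a hypothetical toy `nn.Sequential` model in place of a real ResNet, a dummy batch in place of real data, and a hypothetical helper `set_trainable`. The optimizer receives every parameter up front; plain SGD simply skips any parameter whose `.grad` is `None`, which is the case for the frozen ones.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy network standing in for a ResNet.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)


def set_trainable(model, conv_only):
    """Hypothetical helper: with conv_only=True only nn.Conv2d parameters
    require gradients; with conv_only=False every parameter does."""
    for m in model.modules():
        for p in m.parameters(recurse=False):
            p.requires_grad = (not conv_only) or isinstance(m, nn.Conv2d)


# All parameters are registered; SGD skips those whose .grad is None.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()
x = torch.randn(4, 3, 16, 16)  # dummy batch standing in for real data
y = torch.randint(0, 10, (4,))

conv_before = model[0].weight.detach().clone()
fc_before = model[5].weight.detach().clone()

for epoch in range(1, 21):
    if epoch == 1:
        set_trainable(model, conv_only=True)   # epochs 1-10: convs only
    elif epoch == 11:
        set_trainable(model, conv_only=False)  # epoch 11 on: whole model
    optimizer.zero_grad(set_to_none=True)
    loss_fn(model(x), y).backward()
    optimizer.step()
    if epoch == 10:
        conv_after_phase1 = model[0].weight.detach().clone()
        fc_after_phase1 = model[5].weight.detach().clone()
```

One thing to watch for in a real ResNet: `requires_grad = False` does not stop `BatchNorm` layers from updating their running mean and variance in training mode; putting those layers in `.eval()` mode does.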

Upvotes: 0

Shai

Reputation: 114866

You can set the learning rate (and some other hyperparameters) per parameter group. You only need to group your parameters according to your needs.
For example, setting a different learning rate for the conv layers:

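import torch
from torch import nn
from torchvision import models

model = models.resnet18()  # your network
lr_for_model = 0.1         # base learning rate for the rest of the model

# model.modules() recurses into nested blocks (model.children() only yields the
# top-level layers and would miss most of ResNet's convs); recurse=False makes
# sure every parameter is collected exactly once.
conv_params = [p for m in model.modules() if isinstance(m, nn.Conv2d)
               for p in m.parameters(recurse=False)]
other_params = [p for m in model.modules() if not isinstance(m, nn.Conv2d)
                for p in m.parameters(recurse=False)]
optimizer = torch.optim.SGD([{'params': other_params},
                             {'params': conv_params, 'lr': 0}],  # set init lr to 0
                            lr=lr_for_model)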

Later (e.g. at the start of the 11th epoch), you can access optimizer.param_groups and modify the learning rate of each group.
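A minimal sketch of that switch applied to the question's schedule. Here the non-conv group is the one that starts at lr 0 (the reverse of the example above), so only the convs train in epochs 1 to 10; the toy model, dummy batch, and `base_lr` are hypothetical.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy model standing in for a ResNet.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 10),
)
base_lr = 0.1  # hypothetical base learning rate

conv_params = [p for m in model.modules() if isinstance(m, nn.Conv2d)
               for p in m.parameters(recurse=False)]
other_params = [p for m in model.modules() if not isinstance(m, nn.Conv2d)
                for p in m.parameters(recurse=False)]

# Group 0: conv weights train from the start.
# Group 1: everything else starts with lr 0, i.e. effectively frozen.
optimizer = torch.optim.SGD([{'params': conv_params},
                             {'params': other_params, 'lr': 0.0}],
                            lr=base_lr)

loss_fn = nn.CrossEntropyLoss()
x = torch.randn(4, 3, 16, 16)  # dummy batch standing in for real data
y = torch.randint(0, 10, (4,))

fc_before = model[4].weight.detach().clone()
for epoch in range(1, 21):
    if epoch == 11:
        # From the 11th epoch on, train the whole model.
        for group in optimizer.param_groups:
            group['lr'] = base_lr
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
    if epoch == 10:
        fc_after_phase1 = model[4].weight.detach().clone()
```

Note that with momentum or an adaptive optimizer such as Adam, the optimizer state (e.g. momentum buffers) keeps accumulating for the lr-0 group while it is "frozen", which the `requires_grad` approach avoids.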

See the "Per-parameter options" section of the torch.optim documentation for more information.

Upvotes: 6
