Nathan Wang

Reputation: 11

PyTorch BatchNorm running_mean, running_var and num_batches_tracked are updated during training, but I want to keep them fixed

In PyTorch, I want to use a pretrained model and train my own model to add a delta to its result, that is:

        ╭----- (pretrained model) ------ result ---╮
 input------------- (my model) --------- Δresult --+-- final_result

Here is what I did:

  1. Use load_state_dict to load the pretrained model's parameters
  2. Set requires_grad = False on all of the pretrained model's parameters
  3. Create my model and start training
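The three steps above can be sketched roughly as follows. This is a minimal sketch under assumptions: `make_net`, `pretrained` and `delta_net` are hypothetical names, a small conv stack stands in for both real models, and a freshly built copy stands in for the real pretrained checkpoint.

```python
import torch
import torch.nn as nn


def make_net() -> nn.Module:
    # Hypothetical stand-in architecture; the real pretrained and delta models go here.
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())


# 1. Load the pretrained parameters (a fresh copy stands in for a real checkpoint).
pretrained = make_net()
pretrained.load_state_dict(make_net().state_dict())

# 2. Freeze every parameter of the pretrained model.
for p in pretrained.parameters():
    p.requires_grad = False

# 3. Create the delta model and optimize only its parameters.
delta_net = make_net()
optimizer = torch.optim.SGD(delta_net.parameters(), lr=0.01)

x = torch.randn(4, 3, 16, 16)
target = torch.randn(4, 8, 16, 16)
final_result = pretrained(x) + delta_net(x)  # result + delta
loss = nn.functional.mse_loss(final_result, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Because `pretrained` is still in training mode here, this forward pass already updates its BatchNorm running statistics, even though no gradient ever reaches them.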

But after training, when I check result (the output of the pretrained model), I find that it no longer matches the original pretrained model's output. I carefully compared the pretrained model's state before and after training: since all of its parameters have requires_grad = False, the only values that changed are the BatchNorm2d buffers running_mean, running_var and num_batches_tracked. When I change these three back to their original values, the result matches the original pretrained model's output again.
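The behaviour described above can be reproduced with a single layer. A minimal sketch, assuming one stand-alone `BatchNorm2d` (`bn` and `changed_entries` are hypothetical names): the running statistics are registered buffers rather than parameters, so `requires_grad = False` does not protect them. A training-mode forward pass overwrites them in place, while an eval-mode forward pass only reads them.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(4)
for p in bn.parameters():
    p.requires_grad = False  # freezes only the affine weight and bias

# The running statistics live in buffers, not parameters.
param_names = sorted(name for name, _ in bn.named_parameters())
buffer_names = sorted(name for name, _ in bn.named_buffers())


def changed_entries(before, module):
    # Names of state_dict entries whose values differ from the snapshot.
    return sorted(k for k, v in module.state_dict().items()
                  if not torch.equal(before[k], v))


x = torch.randn(8, 4, 5, 5) + 3.0  # batch mean far from the initial running_mean of 0

bn.train()
before = {k: v.clone() for k, v in bn.state_dict().items()}
bn(x)  # a training-mode forward pass updates the buffers in place
changed_in_train = changed_entries(before, bn)

bn.eval()
before = {k: v.clone() for k, v in bn.state_dict().items()}
bn(x)  # an eval-mode forward pass only reads the stored statistics
changed_in_eval = changed_entries(before, bn)

print(param_names)       # ['bias', 'weight']
print(buffer_names)      # ['num_batches_tracked', 'running_mean', 'running_var']
print(changed_in_train)  # ['num_batches_tracked', 'running_mean', 'running_var']
print(changed_in_eval)   # []
```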

I do not want the pretrained model to change at all. Is there any way to keep running_mean, running_var and num_batches_tracked fixed?

Upvotes: 1

Views: 1843

Answers (1)

aretor

Reputation: 2569

I stumbled upon the same problem, so I adapted the context manager found in this repo as follows:

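import contextlib

import torch.nn as nn


@contextlib.contextmanager
def disable_tracking_bn_stats(model: nn.Module):
    # Every normalization layer that can track running statistics
    # (BatchNorm*d, InstanceNorm*d, ...), plus its current setting.
    modules = [mod for mod in model.modules()
               if hasattr(mod, 'track_running_stats')]
    saved = [mod.track_running_stats for mod in modules]
    for mod in modules:
        # In train mode, BatchNorm now normalizes with the current batch statistics
        # and leaves running_mean, running_var and num_batches_tracked untouched.
        mod.track_running_stats = False
    try:
        yield
    finally:
        # Restore the original flags even if the body raised an exception.
        for mod, flag in zip(modules, saved):
            mod.track_running_stats = flag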

As an alternative, I think you can get a similar result by putting only the BatchNorm modules in eval mode:

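import torch.nn as nn

# In eval mode, BatchNorm uses the stored running statistics and stops updating them.
for layer in net.modules():
    if isinstance(layer, nn.BatchNorm2d):
        layer.eval()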

though the first method is more principled.
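One caveat with the eval approach: calling .train() on the parent model, as most training loops do at the start of each epoch, switches every submodule (these BatchNorm layers included) back to training mode, so the loop has to be re-run afterwards. A minimal sketch of one way to make this automatic, assuming the pretrained network is held as a submodule of a wrapper (`FrozenBaseDelta`, `delta_net` and `make_net` are hypothetical names): override `train()` so that the pretrained branch is always put back into eval mode.

```python
import torch
import torch.nn as nn


class FrozenBaseDelta(nn.Module):
    """Hypothetical wrapper: final_result = pretrained(x) + delta_net(x)."""

    def __init__(self, pretrained: nn.Module, delta_net: nn.Module):
        super().__init__()
        self.pretrained = pretrained
        self.delta_net = delta_net
        for p in self.pretrained.parameters():
            p.requires_grad = False
        self.pretrained.eval()

    def train(self, mode: bool = True):
        # nn.Module.eval() calls train(False), so this covers both directions.
        super().train(mode)
        # Keep the whole pretrained branch (BatchNorm included) in eval mode, so
        # its running statistics are read but never updated.
        self.pretrained.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.pretrained(x) + self.delta_net(x)


def make_net() -> nn.Module:
    # Stand-in architecture for both branches.
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())


model = FrozenBaseDelta(make_net(), make_net())
before = {k: v.clone() for k, v in model.pretrained.state_dict().items()}

model.train()  # what a typical training loop does
out = model(torch.randn(4, 3, 16, 16))
out.pow(2).mean().backward()

unchanged = all(torch.equal(before[k], v)
                for k, v in model.pretrained.state_dict().items())
print(model.pretrained.training, model.delta_net.training, unchanged)  # False True True
```

This also keeps any Dropout layers in the pretrained branch in eval mode, which is usually what you want for a frozen model.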

Upvotes: 1
