Nathan

Reputation: 10306

How can I keep a PyTorch submodule in eval mode?

I have a pretrained model that I'm using in conjunction with a model being trained. I want the pretrained model to always be in eval mode, but the other model will be moving back and forth between eval and train mode. I'd still like the pretrained model to be a submodule of the other one, though (e.g. so that all parameters stay on the same device). Is there a way to do this? Here's a minimal example:

from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

fixed = FixedModule().eval()
assert not fixed.training

trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.fixed_module.training  # I'd like this to give an error

I know I can work around this by, e.g., always doing

trainable.train()
trainable.fixed_module.eval()

but that's error-prone and doesn't work well with existing code.

Upvotes: 2

Views: 2472

Answers (2)

Nathan

Reputation: 10306

You can override train in FixedModule to prevent it from changing modes. Note that eval just calls train(False), so you don't need to override that as well. However, that also means calling FixedModule.eval no longer does anything, so you have to set training = False in __init__ and put any child modules in eval mode there.

from torch import nn

class FixedModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.training = False

        # add any other nn.Module attributes here before calling self.children

        # you could override `train` in each child too if you really wanted,
        # but that seems like overkill unless there are external references
        # to any submodules of FixedModule
        for module in self.children():
            module.eval()

    def train(self, mode=True):
        # default mode=True matches the base signature, so both train() and
        # eval() (which calls train(False)) still work; the requested mode is
        # ignored so this module always stays in eval mode
        return self

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

fixed = FixedModule().eval()
assert not fixed.training

trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training

trainable.train()
assert not trainable.fixed_module.training # passes
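To check that this also protects mode-dependent children like dropout or batch norm, here's a minimal sketch building on the classes above (FixedDropoutModule and its nn.Dropout child are purely illustrative; note the child is added after FixedModule.__init__ has run its eval loop, so it needs an explicit eval call):

class FixedDropoutModule(FixedModule):
    def __init__(self):
        super().__init__()           # sets training = False
        self.dropout = nn.Dropout()  # added after the eval loop in FixedModule.__init__
        self.dropout.eval()          # so it must be put in eval mode explicitly

fixed = FixedDropoutModule()
trainable = TrainableModule(fixed)

trainable.train()
# the overridden train never propagates the mode change to the child
assert not trainable.fixed_module.dropout.training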

Upvotes: 0

adamconkey

Reputation: 4745

One solution is to override train in TrainableModule so that it always puts fixed_module back into eval mode:

from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

    def train(self, mode=True):
        # accept mode so that eval(), which calls train(False), keeps working
        super().train(mode)
        self.fixed_module.eval()  # force the fixed submodule back into eval mode
        return self

fixed = FixedModule().eval()
assert not fixed.training

trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.fixed_module.training  # This gives an error now
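Since the override accepts the mode argument, eval still works as well (eval just calls train(False)). A quick round-trip check, reusing the classes above:

fixed = FixedModule().eval()
trainable = TrainableModule(fixed)

trainable.eval()   # calls train(False), which the override accepts
assert not trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.training                    # the wrapper is back in train mode
assert not trainable.fixed_module.training   # the fixed submodule stays in eval mode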

Upvotes: 4
