Reputation: 10306
I have a pretrained model that I'm using in conjunction with a model being trained. I want the pretrained model to always be in eval mode, but the other model will be moving back and forth between eval and train mode. I'd still like the pretrained model to be a submodule of the other one, though (e.g. so that all parameters stay on the same device). Is there a way to do this? Here's a minimal example:
from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

fixed = FixedModule().eval()
assert not fixed.training
trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training
trainable.train()
assert trainable.fixed_module.training  # I'd like this to give an error
I know I can work around this by, e.g., always doing
trainable.train()
trainable.fixed_module.eval()
but that's error-prone and doesn't work well with existing code.
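For context, here is roughly what that workaround looks like around a typical train/validate loop (just a sketch; the extra eval() call has to be repeated at every place that switches modes, which is what makes it error-prone):

from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

trainable = TrainableModule(FixedModule().eval())

for epoch in range(3):
    trainable.train()
    trainable.fixed_module.eval()  # easy to forget at every call site
    # ... training step ...
    trainable.eval()
    # ... validation step ...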
Upvotes: 2
Views: 2472
Reputation: 10306
You can override train in FixedModule to prevent it from changing modes. Note that eval just calls train(False), so you don't need to override that too. But since calling FixedModule.eval won't do anything now, you have to set training = False in __init__.
from torch import nn

class FixedModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.training = False
        # add any other nn.Module attributes here before calling self.children
        # you could override `train` in each child too if you really wanted,
        # but that seems like overkill unless there are external references
        # to any submodules of FixedModule
        for module in self.children():
            module.eval()

    def train(self, mode=True):
        # ignore mode changes; match nn.Module's signature so eval()/train() calls don't break
        return self

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

fixed = FixedModule().eval()
assert not fixed.training
trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training
trainable.train()
assert not trainable.fixed_module.training  # passes
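To see that the children really stay frozen once the fixed module has actual submodules, here is a minimal sketch (FixedEncoder, Wrapper, and the nn.Linear layers are hypothetical stand-ins for a pretrained model and a trainable head; note the submodule is assigned before the self.children() loop, as the comment above suggests):

from torch import nn

class FixedEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)  # stand-in for a pretrained submodule
        self.training = False
        for module in self.children():  # encoder exists by now, so it gets put in eval mode
            module.eval()

    def train(self, mode=True):
        return self  # ignore mode changes, exactly like FixedModule above

class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.fixed = FixedEncoder()
        self.head = nn.Linear(4, 2)

wrapper = Wrapper()
wrapper.train()
assert wrapper.head.training               # the trainable part follows the parent
assert not wrapper.fixed.training          # the frozen wrapper stays in eval mode
assert not wrapper.fixed.encoder.training  # and so do its children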
Upvotes: 0
Reputation: 4745
One solution is to override train in TrainableModule, accepting and forwarding the mode argument (eval just calls train(False), so the override must take it) and putting the fixed submodule back in eval mode afterwards:
from torch import nn

class FixedModule(nn.Module):
    pass

class TrainableModule(nn.Module):
    def __init__(self, fixed_module):
        super().__init__()
        self.fixed_module = fixed_module

    def train(self, mode=True):
        super().train(mode)
        self.fixed_module.eval()
        return self

fixed = FixedModule().eval()
assert not fixed.training
trainable = TrainableModule(fixed)
assert trainable.training and not trainable.fixed_module.training
trainable.train()
assert trainable.fixed_module.training  # This gives an error now
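Since the override forwards the mode flag to super().train(mode), calling trainable.eval() keeps working as well, and switching back to train mode afterwards still leaves fixed_module in eval mode. A quick sketch of that check (run in place of the deliberately failing assert above):

trainable.eval()
assert not trainable.training and not trainable.fixed_module.training

trainable.train()
assert trainable.training                   # the wrapper toggles normally
assert not trainable.fixed_module.training  # the pretrained part stays in eval mode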
Upvotes: 4