Reputation: 41
So from what I can tell by reading the docs, the default optimizer used in detectron2 is SGD with momentum. But I was wondering if there is a way to change the default optimizer to something like Adam for example.
The closest I was able to get was by looking at the `cfg.SOLVER` attributes, specifically `cfg.SOLVER.OPTIMIZER`. I even took a look at their source code to see if different optimizer options were available.
However, I was unable to change the optimizer type.
Any help will be greatly appreciated!
Upvotes: 3
Views: 2076
Reputation: 432
While it's best not to modify the source code, you may want to change `torch.optim.SGD` in the following lines (line 129 of `detectron2/detectron2/solver/build.py`, inside the `build_optimizer` function):
```python
def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
    """
    Build an optimizer from config.
    """
    params = get_default_optimizer_params(
        model,
        base_lr=cfg.SOLVER.BASE_LR,
        weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
        bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
        weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
    )
    return maybe_add_gradient_clipping(cfg, torch.optim.SGD)(
        params,
        lr=cfg.SOLVER.BASE_LR,
        momentum=cfg.SOLVER.MOMENTUM,
        nesterov=cfg.SOLVER.NESTEROV,
        weight_decay=cfg.SOLVER.WEIGHT_DECAY,
    )
```
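Swapping the class works because, in recent detectron2 versions, `get_default_optimizer_params` returns a list of parameter-group dicts, a format every `torch.optim` optimizer accepts. The torch-only sketch below assumes that format. The `make_param_groups` helper is hypothetical and only imitates the kind of per-group overrides detectron2 builds (here, a scaled learning rate and no weight decay for biases); the toy `nn.Linear` model and the hyperparameter values are arbitrary.

```python
import torch
from torch import nn


def make_param_groups(model: nn.Module, base_lr: float, bias_lr_factor: float):
    """Hypothetical stand-in for detectron2's parameter groups:
    biases get a scaled lr and no weight decay; everything else uses the defaults."""
    weights = [p for n, p in model.named_parameters() if not n.endswith("bias")]
    biases = [p for n, p in model.named_parameters() if n.endswith("bias")]
    return [
        {"params": weights},
        {"params": biases, "lr": base_lr * bias_lr_factor, "weight_decay": 0.0},
    ]


model = nn.Linear(4, 2)  # toy model standing in for a detectron2 model
base_lr = 0.01

# Build fresh groups per optimizer: torch fills missing defaults into the dicts in place.
sgd = torch.optim.SGD(make_param_groups(model, base_lr, 2.0),
                      lr=base_lr, momentum=0.9, weight_decay=1e-4)
adam = torch.optim.Adam(make_param_groups(model, base_lr, 2.0),
                        lr=base_lr, weight_decay=1e-4)

for opt in (sgd, adam):
    # Group-level keys ("lr", "weight_decay") override the constructor defaults.
    print(type(opt).__name__, [(g["lr"], g["weight_decay"]) for g in opt.param_groups])
```

A fresh list is built for each optimizer because sharing one list would let SGD's filled-in defaults (such as `momentum`) end up in Adam's groups.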
Upvotes: 0
Reputation: 79
You can create a subclass of `DefaultTrainer` and override its `build_optimizer` method. Check out the code below, which uses Adam:
```python
import torch

from detectron2.config import CfgNode
from detectron2.engine import DefaultTrainer
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping


class MyTrainer(DefaultTrainer):
    @classmethod
    def build_optimizer(cls, cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer:
        """
        Build an Adam optimizer from config (replaces the default SGD).
        """
        # Same parameter groups (per-group lr / weight decay) the default SGD would use.
        params = get_default_optimizer_params(
            model,
            base_lr=cfg.SOLVER.BASE_LR,
            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
        )
        # Adam takes no momentum/nesterov arguments, so only lr and weight decay are passed.
        return maybe_add_gradient_clipping(cfg, torch.optim.Adam)(
            params,
            lr=cfg.SOLVER.BASE_LR,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
```
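If you want to switch optimizers from the config instead of editing the class each time, one option is to look the optimizer up by name. This is a sketch under assumptions: `OPTIMIZER_NAME` is a hypothetical key you would add yourself, `optimizer_class_and_kwargs` is a hypothetical helper, and a `SimpleNamespace` stands in for `cfg.SOLVER` so the example runs with plain PyTorch.

```python
from types import SimpleNamespace

import torch
from torch import nn

# Hypothetical lookup table from a config string to a torch optimizer class.
OPTIMIZER_CLASSES = {
    "SGD": torch.optim.SGD,
    "ADAM": torch.optim.Adam,
    "ADAMW": torch.optim.AdamW,
}


def optimizer_class_and_kwargs(solver):
    """Return (optimizer class, constructor kwargs) for the configured name."""
    name = solver.OPTIMIZER_NAME.upper()
    if name not in OPTIMIZER_CLASSES:
        raise ValueError(f"unknown optimizer {name!r}; expected one of {sorted(OPTIMIZER_CLASSES)}")
    kwargs = {"lr": solver.BASE_LR, "weight_decay": solver.WEIGHT_DECAY}
    if name == "SGD":
        # Only SGD takes momentum/nesterov; Adam and AdamW raise TypeError on them.
        kwargs.update(momentum=solver.MOMENTUM, nesterov=solver.NESTEROV)
    return OPTIMIZER_CLASSES[name], kwargs


# SimpleNamespace stands in for cfg.SOLVER; the values are arbitrary examples.
solver = SimpleNamespace(OPTIMIZER_NAME="adam", BASE_LR=1e-4, MOMENTUM=0.9,
                         NESTEROV=False, WEIGHT_DECAY=1e-4)
cls, kwargs = optimizer_class_and_kwargs(solver)
opt = cls(nn.Linear(4, 2).parameters(), **kwargs)
print(type(opt).__name__)  # prints "Adam"
```

Inside `MyTrainer.build_optimizer` above, the `return` statement could then become `cls, kwargs = optimizer_class_and_kwargs(cfg.SOLVER)` followed by `return maybe_add_gradient_clipping(cfg, cls)(params, **kwargs)`, which keeps detectron2's parameter groups and gradient clipping. The key would need to be added in code, e.g. `cfg.SOLVER.OPTIMIZER_NAME = "ADAM"`, before the config is frozen.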
Upvotes: 4