James Arten

Reputation: 666

Conditional initialization of parameters in Hydra

I'm pretty new to Hydra and was wondering whether the following is possible: I have the parameter num_atom_feats in the model section, and I would like it to depend on the feat_type parameter in the data section. In particular, if I have feat_type: type1, then I would like num_atom_feats: 22. If instead I initialize data with feat_type: type2, then I would like num_atom_feats: 200. Here is my current config:

model:
  _target_: model.EmbNet_Lightning
  model_name: 'EmbNet'
  num_atom_feats: 22
  dim_target: 128
  loss: 'log_ratio'
  lr: 1e-3
  wd: 5e-6

data:
  _target_: data.DataModule
  feat_type: 'type1'
  batch_size: 64
  data_path: '.'

wandb:
  _target_:  pytorch_lightning.loggers.WandbLogger
  name: embnet_logger
  project: ''

trainer:
  max_epochs: 1000

Upvotes: 1

Views: 2275

Answers (1)

Jasha

Reputation: 7639

You can achieve this using OmegaConf's custom resolver feature.

Here's an example showing how to register a custom resolver that computes model.num_atom_feats based on the value of data.feat_type:

from omegaconf import OmegaConf

yaml_data = """
model:
  _target_: model.EmbNet_Lightning
  model_name: 'EmbNet'
  num_atom_feats: ${compute_num_atom_feats:${data.feat_type}}

data:
  _target_: data.DataModule
  feat_type: 'type1'
"""

def compute_num_atom_feats(feat_type: str) -> int:
    if feat_type == "type1":
        return 22
    if feat_type == "type2":
        return 200
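    # Fail loudly on unexpected values (an assert would be skipped under python -O)
    raise ValueError(f"Unknown feat_type: {feat_type}")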

OmegaConf.register_new_resolver("compute_num_atom_feats", compute_num_atom_feats)


cfg = OmegaConf.create(yaml_data)

assert cfg.data.feat_type == 'type1'
assert cfg.model.num_atom_feats == 22
cfg.data.feat_type = 'type2'
assert cfg.model.num_atom_feats == 200

I'd recommend reading through the OmegaConf docs; OmegaConf is the configuration library that Hydra uses as its backend.

The compute_num_atom_feats function is invoked lazily, i.e. only when you access cfg.model.num_atom_feats in your Python code.

When using custom resolvers with Hydra, you can call OmegaConf.register_new_resolver either before your @hydra.main-decorated function is invoked or from within that function. What matters is that the resolver is registered before you access cfg.model.num_atom_feats.

Upvotes: 1
