Emre

Reputation: 6226

How to use callable targets with dataclass arguments in Hydra?

Is it possible to define a target using Structured Configs to avoid redefining all the parameters?

from dataclasses import dataclass
from typing import Any


def good(config: "Config"):
    pass

def bad(param1, param2):
    pass

@dataclass
class Config:
    param1: Any
    param2: Any
    _target_: Any = good
    # _target_: Any = bad
    # _target_: str = 'Config.also_good'

    def also_good(self):
        pass

What type annotation should I use for _target_ when the target is a class, a function, or a method? When I used Any, I got:

omegaconf.errors.UnsupportedValueType: Value 'function' is not a supported primitive type
    full_key: _target_

Upvotes: 0

Views: 3929

Answers (1)

Jasha

Reputation: 7639

The _target_ type should be str. Here is an example using the instantiate API with a structured config:

# example.py
from dataclasses import dataclass

from hydra.utils import instantiate


def trgt(arg1: int, arg2: float):
    print(f"trgt function: got {arg1}, {arg2}")
    return "foobar"


@dataclass
class Config:
    _target_: str = "__main__.trgt"  # dotpath describing location of callable
    arg1: int = 123
    arg2: float = 10.1


val = instantiate(Config)
print(f"Returned value was {val}.")

Running the script:

$ python example.py
trgt function: got 123, 10.1
Returned value was foobar.
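If you would rather not type the dotpath by hand, one possible approach (my assumption, not something Hydra requires) is to build the string from the callable's __module__ and __qualname__ attributes. The helper name dotpath below is hypothetical. The field stays annotated as str, so OmegaConf only ever sees a string. Because a method's __qualname__ includes its class name (e.g. Config.also_good), the same helper also yields a string for classes and methods, but nested functions and lambdas produce names containing "<locals>" or "<lambda>" that cannot be imported.

```python
from dataclasses import dataclass


def dotpath(obj) -> str:
    """Hypothetical helper: build a "module.qualname" string for a callable.

    Works for module-level functions, classes and methods defined on a class.
    It does NOT give an importable path for nested functions or lambdas.
    """
    return f"{obj.__module__}.{obj.__qualname__}"


def trgt(arg1: int, arg2: float):
    print(f"trgt function: got {arg1}, {arg2}")
    return "foobar"


@dataclass
class Config:
    # Still annotated as str; only the default value is computed from trgt.
    _target_: str = dotpath(trgt)
    arg1: int = 123
    arg2: float = 10.1


print(Config()._target_)  # prints "__main__.trgt" when run as a script
```

Passing this Config to instantiate should then behave like the example above, since the value Hydra receives is the same "__main__.trgt" string.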

Notes:

  • The _target_ field must be a string holding the "dotpath" used to look up a callable (__main__.trgt in the example above). Other typical dotpaths are my_module.my_function, builtins.range, numpy.random.randn, etc. The instantiate function uses this dotpath to find the callable object that will be instantiated (the trgt function, in this example).
  • The fields of the structured config (except for the _target_ field and other special reserved fields such as _recursive_, _convert_ and _args_) are passed as keyword arguments to the looked-up callable. In this example, the function trgt is called with keyword arguments arg1=123 and arg2=10.1, which is equivalent to the Python code trgt(arg1=123, arg2=10.1).
  • If you define a structured config (using @dataclass or @attr.s), you can specify a default value for each field. For example, the default value for field arg1 above is 123. Calling instantiate(config) raises an exception if any field of config is missing a value (e.g. a field left as omegaconf.MISSING). Typically, a structured config's missing values are filled in during Hydra's config composition process.
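To make the notes above concrete, here is a deliberately simplified, stdlib-only sketch of the "look up the dotpath, then call it with the remaining fields" behaviour. It is not Hydra's implementation (the real instantiate also handles nested configs, _recursive_, _convert_ and OmegaConf containers); the names locate, toy_instantiate and RESERVED are hypothetical, and a missing value is represented by the string "???", which is the marker OmegaConf uses for MISSING.

```python
import importlib
from typing import Any, Callable, Dict

MISSING = "???"  # same marker string OmegaConf uses for missing values
RESERVED = {"_target_", "_recursive_", "_convert_", "_args_"}


def locate(path: str) -> Callable[..., Any]:
    """Hypothetical resolver: import the longest importable module prefix
    of `path`, then walk the remaining parts with getattr."""
    parts = path.split(".")
    for i in range(len(parts), 0, -1):
        module_name = ".".join(parts[:i])
        try:
            obj: Any = importlib.import_module(module_name)
        except ImportError:  # e.g. "builtins.range" is not a module
            continue
        for attr in parts[i:]:
            obj = getattr(obj, attr)
        return obj
    raise ImportError(f"cannot locate {path!r}")


def toy_instantiate(cfg: Dict[str, Any]) -> Any:
    """Hypothetical, simplified stand-in for hydra.utils.instantiate."""
    # Note 3: refuse to call the target while any value is still missing.
    missing = [k for k, v in cfg.items() if isinstance(v, str) and v == MISSING]
    if missing:
        raise ValueError(f"missing mandatory value(s): {missing}")
    # Note 1: the dotpath string is resolved to a callable.
    target = locate(cfg["_target_"])
    # Note 2: reserved keys are not forwarded as keyword arguments.
    args = cfg.get("_args_", [])
    kwargs = {k: v for k, v in cfg.items() if k not in RESERVED}
    return target(*args, **kwargs)


print(toy_instantiate({"_target_": "builtins.range", "_args_": [1, 10, 3]}))  # range(1, 10, 3)
print(toy_instantiate({"_target_": "builtins.dict", "a": 1}))  # {'a': 1}
```

The prefix-then-getattr loop is one simple way to resolve paths such as builtins.range or os.path.join, where only part of the dotpath is an importable module.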


Upvotes: 4
