Mark Fingerhuth

Reputation: 143

Kedro - how to pass nested parameters directly to node

Kedro recommends storing parameters in conf/base/parameters.yml. Let's assume it looks like this:

step_size: 1
model_params:
    learning_rate: 0.01
    test_data_ratio: 0.2
    num_train_steps: 10000

And now imagine I have some data_engineering pipeline whose nodes.py has a function that looks something like this:

def some_pipeline_step(num_train_steps):
    """
    Takes the parameter `num_train_steps` as argument.
    """
    pass

How would I pass that nested parameter straight to this function in data_engineering/pipeline.py? I unsuccessfully tried:

from kedro.pipeline import Pipeline, node

from .nodes import some_pipeline_step


def create_pipeline(**kwargs):
    return Pipeline(
        [
            node(
                some_pipeline_step,
                ["params:model_params.num_train_steps"],
                dict(
                    train_x="train_x",
                    train_y="train_y",
                ),
            )
        ]
    )

I know that I could pass all parameters into the function by using ['parameters'], or pass all of the model_params parameters with ['params:model_params'], but that seems inelegant and I feel like there must be a better way. Would appreciate any input!

Upvotes: 2

Views: 3605

Answers (2)

Tomasz Bartkowiak

Reputation: 15008

As mentioned by Dmitry, Kedro 0.16.0 introduced support for nested parameter values in node inputs, which can be accessed with the . operator:

node(func, "params:a.b", None)
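To illustrate what the dotted key refers to, here is a minimal plain-Python sketch that walks the parameters from the question. It does not use Kedro and is not how Kedro resolves parameters internally; `resolve_param` is a hypothetical helper introduced only for this illustration.

```python
from functools import reduce

# Parameters as they would be loaded from conf/base/parameters.yml in the question.
params = {
    "step_size": 1,
    "model_params": {
        "learning_rate": 0.01,
        "test_data_ratio": 0.2,
        "num_train_steps": 10000,
    },
}


def resolve_param(parameters, dataset_name):
    """Hypothetical helper: look up a "params:a.b" style name in a nested dict."""
    prefix = "params:"
    if not dataset_name.startswith(prefix):
        raise ValueError(f"expected a name starting with {prefix!r}")
    path = dataset_name[len(prefix):].split(".")
    # Descend one dictionary level per dotted segment.
    return reduce(lambda level, key: level[key], path, parameters)


num_train_steps = resolve_param(params, "params:model_params.num_train_steps")
print(num_train_steps)  # 10000
```

So on 0.16.0 or later, the node from the question can take "params:model_params.num_train_steps" as its input, and some_pipeline_step receives the value 10000.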

Kedro 0.17.6 additionally made it possible to override nested parameters with --params on the CLI, e.g.

kedro run --params="model.model_tuning.booster:gbtree"
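As a rough picture of what such an override does to the nested parameters, the following plain-Python sketch applies a `key:value` string with a dotted key to a dictionary. It is an illustration under assumptions, not Kedro's CLI code: `apply_override` is a hypothetical helper, the starting values are made up, and the value is kept as a string rather than converted to another type.

```python
import copy


def apply_override(parameters, override):
    """Hypothetical helper: apply "a.b.c:value" to a nested dict, returning a copy."""
    dotted_key, _, value = override.partition(":")
    result = copy.deepcopy(parameters)
    *parents, leaf = dotted_key.split(".")
    level = result
    for key in parents:
        # Create intermediate dictionaries when the path does not exist yet.
        level = level.setdefault(key, {})
    level[leaf] = value  # kept as a plain string in this sketch
    return result


# Made-up starting parameters, for illustration only.
params = {"model": {"model_tuning": {"booster": "gblinear", "eta": 0.3}}}
updated = apply_override(params, "model.model_tuning.booster:gbtree")
print(updated["model"]["model_tuning"])
```

Only the addressed leaf changes; sibling keys such as `eta` keep their values, and the original dictionary is left untouched.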

Upvotes: 2

Dmitry Deryabin

Reputation: 1578

(Disclaimer: I'm part of the Kedro team)

Thank you for your question. Unfortunately, the current version of Kedro does not support nested parameters. An interim solution would be to use top-level keys inside the node (as you already pointed out) or to decorate your node function with some sort of parameter filter, which is not elegant either.
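Such a parameter filter could look roughly like the sketch below. It is only one possible shape, written in plain Python: `pick_param` is a hypothetical decorator name, and the sketch assumes the node receives the whole `model_params` dictionary (for example via "params:model_params") as its single input.

```python
from functools import wraps


def pick_param(key):
    """Hypothetical decorator: pass only `param_dict[key]` on to the wrapped function."""
    def decorator(func):
        @wraps(func)
        def wrapper(param_dict):
            return func(param_dict[key])
        return wrapper
    return decorator


@pick_param("num_train_steps")
def some_pipeline_step(num_train_steps):
    """Takes the parameter `num_train_steps` as argument."""
    return num_train_steps


# The node would be declared with "params:model_params" as its input.
model_params = {"learning_rate": 0.01, "test_data_ratio": 0.2, "num_train_steps": 10000}
print(some_pipeline_step(model_params))  # 10000
```

The function body stays unaware of the surrounding dictionary, but every node that needs a nested value has to carry its own decorator, which is why this remains a workaround.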

Probably the most viable solution would be to customise your ProjectContext class (in src/<package_name>/run.py) by overriding the _get_feed_dict method as follows:

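from typing import Any, Dict

from kedro.context import KedroContext  # import path as in the Kedro 0.15.x project template


class ProjectContext(KedroContext):
    # ...

    def _get_feed_dict(self) -> Dict[str, Any]: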
        """Get parameters and return the feed dictionary."""
        params = self.params
        feed_dict = {"parameters": params}

        def _add_param_to_feed_dict(param_name, param_value):
            """This recursively adds parameter paths to the `feed_dict`,
            whenever `param_value` is a dictionary itself, so that users can
            specify specific nested parameters in their node inputs.

            Example:

                >>> param_name = "a"
                >>> param_value = {"b": 1}
                >>> _add_param_to_feed_dict(param_name, param_value)
                >>> assert feed_dict["params:a"] == {"b": 1}
                >>> assert feed_dict["params:a.b"] == 1
            """
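            feed_key = "params:{}".format(param_name)
            feed_dict[feed_key] = param_value

            if isinstance(param_value, dict):
                # Recurse into nested dictionaries, extending the dotted path.
                for child_name, child_value in param_value.items():
                    _add_param_to_feed_dict("{}.{}".format(param_name, child_name), child_value)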

        for param_name, param_value in params.items():
            _add_param_to_feed_dict(param_name, param_value)

        return feed_dict

Please also note that this issue has already been addressed on the develop branch and the fix will become available in the next release. It uses the approach from the snippet above.

Upvotes: 3
