BBQuercus

Reputation: 879

Stop Metaflow from parallelising foreach steps

I recently started using Metaflow for my hyperparameter searches. I'm using a foreach for all my parameters as follows:

from metaflow import FlowSpec, step

@step
def start_hpo(self):
    self.next(self.train_model, foreach='hpo_parameters')

@step
def train_model(self):
    # Trains model...

This works in that it starts the train_model step as intended, but unfortunately it tries to run all the foreach tasks in parallel at once. This makes my GPU/CPU run out of memory instantly, failing the step.

Is there a way to tell Metaflow to run these steps linearly / one at a time, or is there another workaround?

Thanks

Upvotes: 2

Views: 1551

Answers (2)

crypdick

Reputation: 19786

As mentioned in the other answer, you can control this at the flow level with the --max-workers flag.

To permanently override --max-workers for a flow, here is a decorator. It can also be used to override other Metaflow flags, such as --max-num-splits.

import sys
import logging

logger = logging.getLogger(__name__)


def fix_cli_args(**kwargs: str):
    """
    Decorator to override Metaflow CLI arguments.

    Usage:
        @fix_cli_args(**{"--max-workers": "1", "--max-num-splits": "100"})
        class InferencePipeline(FlowSpec): ...

    Warnings:
        If an argument is passed by the user, it will be overridden by the value
        specified in the decorator and a warning will be logged.
    """

    def decorator(pipeline):
        def wrapper():
            if "run" not in sys.argv and "resume" not in sys.argv:
                # ignore this decorator if we are not running or resuming a flow
                return pipeline()
            for arg, val in kwargs.items():
                if arg in sys.argv:  # if the arg was passed, override its value
                    ind = sys.argv.index(arg)
                    logger.warning(f"`{arg}` arg was passed with value `{sys.argv[ind + 1]}`. "
                                   f"However, this value will be overridden by @fix_cli_args "
                                   f"with value `{val}`")
                    sys.argv[ind + 1] = val  # replace the value
                else:  # otherwise, add (arg, val) to the call
                    sys.argv.extend([arg, val])
            logger.info(f"Fixed CLI args for {list(kwargs)}")
            return pipeline()

        return wrapper

    return decorator
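To see what the decorator actually does, here is a self-contained sketch: it repeats a minimal version of fix_cli_args (without logging) and applies it to fake_flow, a hypothetical zero-argument callable standing in for a FlowSpec subclass, whose instantiation is what normally kicks off a Metaflow run. Note the **{...} unpacking trick, which lets you pass dashed flag names that aren't valid Python identifiers.

```python
import sys


def fix_cli_args(**kwargs):
    def decorator(pipeline):
        def wrapper():
            if "run" not in sys.argv and "resume" not in sys.argv:
                # not running/resuming a flow: leave sys.argv untouched
                return pipeline()
            for arg, val in kwargs.items():
                if arg in sys.argv:
                    # flag was passed on the CLI: overwrite its value
                    sys.argv[sys.argv.index(arg) + 1] = val
                else:
                    # flag was not passed: append it
                    sys.argv.extend([arg, val])
            return pipeline()
        return wrapper
    return decorator


# Hypothetical stand-in for a FlowSpec subclass.
@fix_cli_args(**{"--max-workers": "1"})
def fake_flow():
    return list(sys.argv)


sys.argv = ["myflow.py", "run", "--max-workers", "32"]
print(fake_flow())  # ['myflow.py', 'run', '--max-workers', '1']
```

The user-supplied value 32 is silently rewritten to 1 before the flow starts, which is exactly how the decorator pins --max-workers for the real pipeline.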

Upvotes: 0

Savin

Reputation: 191

@BBQuercus You can limit parallelization with the --max-workers flag.

Currently, we run no more than 16 tasks in parallel; you can override this, e.g. python myflow.py run --max-workers 32. In your case, --max-workers 1 would run the foreach tasks one at a time.

Upvotes: 2
