Reputation: 879
I recently started using Metaflow for my hyperparameter searches. I'm using a foreach over all my parameters, as follows:
from metaflow import FlowSpec, step

class HPOFlow(FlowSpec):

    @step
    def start_hpo(self):
        self.next(self.train_model, foreach='hpo_parameters')

    @step
    def train_model(self):
        # Trains model...
This works in that it starts the train_model step as intended, but unfortunately it tries to parallelise all the foreach tasks at once. This makes my GPU/CPU run out of memory instantly, failing the step.
Is there a way to tell Metaflow to run these tasks linearly / one at a time instead, or is there another workaround?
Thanks
Upvotes: 2
Views: 1551
Reputation: 19786
As mentioned, you can control this at the flow level using the --max-workers flag.
To permanently override --max-workers for a flow, here is a decorator. It can also be used to override other Metaflow flags, such as --max-num-splits.
import sys
import logging

logger = logging.getLogger(__name__)

def fix_cli_args(**kwargs: str):
    """
    Decorator to override Metaflow CLI arguments.
    Usage:
        @fix_cli_args(**{"--max-workers": "1", "--max-num-splits": "100"})
        class InferencePipeline(FlowSpec): ...
    Warnings:
        If an argument is also passed by the user, it will be overridden by the
        value specified in the decorator and a warning will be logged.
    """
    def decorator(pipeline):
        def wrapper():
            if "run" not in sys.argv and "resume" not in sys.argv:
                # Ignore this decorator if we are not running or resuming a flow
                return pipeline()
            for arg, val in kwargs.items():
                if arg in sys.argv:  # if the arg was passed, override its value
                    ind = sys.argv.index(arg)
                    logger.warning(f"`{arg}` was passed with value `{sys.argv[ind + 1]}`. "
                                   f"It will be overridden by @fix_cli_args with value `{val}`.")
                    sys.argv[ind + 1] = val  # replace the value in place
                else:  # otherwise, add (arg, val) to the call
                    sys.argv.extend([arg, val])
            logger.info(f"Fixed CLI args for {list(kwargs.keys())}")
            return pipeline()
        return wrapper
    return decorator
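For example, a minimal sketch of how you might apply it to the flow from the question (assuming the flow is saved as hpo_flow.py and uses the HPOFlow class name from above):

from metaflow import FlowSpec, step

@fix_cli_args(**{"--max-workers": "1"})
class HPOFlow(FlowSpec):
    ...  # steps as in the question

if __name__ == '__main__':
    HPOFlow()  # now behaves as if --max-workers 1 were always passed

With this in place, python hpo_flow.py run will execute the foreach tasks one at a time, even if the user passes a different --max-workers value on the command line.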
Upvotes: 0
Reputation: 191
@BBQuercus You can limit parallelization by using the --max-workers flag.
Currently, we run no more than 16 tasks in parallel, and you can override that, for example: python myflow.py run --max-workers 32
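So, to run your foreach tasks strictly one at a time (assuming your flow is saved as hpo_flow.py):

python hpo_flow.py run --max-workers 1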
Upvotes: 2