Reputation: 929
I have a question about SageMaker and Hydra.
TL;DR: Is there a way to pass arguments from a SageMaker estimator to a Hydra script? Currently, SageMaker passes parameters in a very strict format (as `--key value` command-line arguments).
Full question: I use Hydra to pass configs to my training script. I have many configs, and it works well for me. For example, if I want to use a specific optimizer, I run:
python train.py optimizer=adam
For instance, this is my training script:
```python
import logging

import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="configs/", config_name="config")
def train(config: DictConfig):
    logging.info(f"Instantiating dataset <{config.dataset._target_}>")
    train_ds, val_ds = hydra.utils.call(config.dataset)

    logging.info(f"Instantiating model <{config.model._target_}>")
    model = hydra.utils.call(config.model)

    logging.info(f"Instantiating optimizer <{config.optimizer._target_}>")
    optimizer = hydra.utils.instantiate(config.optimizer)

    logging.info(f"Instantiating loss <{config.loss._target_}>")
    loss = hydra.utils.instantiate(config.loss)

    # Instantiate every callback config that declares a _target_
    callbacks = []
    if "callbacks" in config:
        for _, cb_conf in config.callbacks.items():
            if "_target_" in cb_conf:
                logging.info(f"Instantiating callback <{cb_conf._target_}>")
                callbacks.append(hydra.utils.instantiate(cb_conf))

    # Instantiate every metric config that declares a _target_
    metrics = []
    if "metrics" in config:
        for _, metric_conf in config.metrics.items():
            if "_target_" in metric_conf:
                logging.info(f"Instantiating metric <{metric_conf._target_}>")
                metrics.append(hydra.utils.instantiate(metric_conf))

    model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
    model.fit(
        train_ds,
        validation_data=val_ds,
        epochs=config.epochs,
        callbacks=callbacks,
    )


if __name__ == "__main__":
    train()
```
I also have a corresponding `optimizer/adam.yaml` file.
I have now started using SageMaker to run my experiments in the cloud, and I noticed a problem: it doesn't support Hydra's override syntax (e.g. `+optimizer=sgd`).
Is there a way to make it play nicely with Hydra syntax? If not, do you have a suggestion for refactoring my training code so that it would work nicely with Hydra/OmegaConf?
I saw a similar question on the SageMaker Python SDK issues page, but it doesn't have any replies: https://github.com/aws/sagemaker-python-sdk/issues/1837
Upvotes: 2
Views: 773
Reputation: 7639
You might consider using Hydra's Compose API. This way, you can preprocess the command-line arguments passed to your Python program before Hydra ingests them.
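Here is a minimal sketch of that idea. It assumes that SageMaker forwards each hyperparameter as a `--key value` (or `--key=value`) argument and that the config layout matches the question (`configs/config.yaml`). The helper names `sagemaker_args_to_overrides` and `load_config` are hypothetical; `initialize` and `compose` are Hydra's Compose API (`version_base` needs Hydra 1.2 or later).

```python
from typing import List


def sagemaker_args_to_overrides(argv: List[str]) -> List[str]:
    """Turn SageMaker-style args (['--optimizer', 'sgd']) into Hydra overrides (['optimizer=sgd'])."""
    overrides = []
    i = 0
    while i < len(argv):
        arg = argv[i]
        if not arg.startswith("--"):
            raise ValueError(f"Expected a --key argument, got {arg!r}")
        key = arg[2:]
        if "=" in key:
            # Already in --key=value form: just drop the leading dashes
            overrides.append(key)
            i += 1
        elif i + 1 < len(argv):
            # --key value form: pair the key with the following token
            overrides.append(f"{key}={argv[i + 1]}")
            i += 2
        else:
            raise ValueError(f"Missing value for {arg!r}")
    return overrides


def load_config(argv: List[str]):
    """Compose the Hydra config from SageMaker-style command-line arguments."""
    # Imported here so the conversion helper above does not depend on Hydra
    from hydra import compose, initialize

    overrides = sagemaker_args_to_overrides(argv)
    # config_path is resolved relative to this file, as with @hydra.main
    with initialize(version_base=None, config_path="configs"):
        return compose(config_name="config", overrides=overrides)
```

The entry point would then call `load_config(sys.argv[1:])` instead of relying on `@hydra.main`, and pass the returned `DictConfig` to the body of `train`. Note that the Compose API does not set up Hydra's output directory or logging the way `@hydra.main` does.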
Upvotes: 0
Reputation: 1314
You could pass the arguments as environment variables and read them in your training script.
The estimator accepts a dict of environment variables through its `environment` parameter: https://sagemaker.readthedocs.io/en/stable/api/training/estimators.html#estimators
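A minimal sketch of the reading side follows. It assumes the overrides are packed into a single environment variable; the name `HYDRA_OVERRIDES` and the helper `overrides_from_env` are hypothetical.

```python
import os
from typing import List, Mapping, Optional

# Hypothetical variable name; set it on the estimator, e.g.
#   environment={"HYDRA_OVERRIDES": "+optimizer=sgd epochs=5"}
OVERRIDES_ENV_VAR = "HYDRA_OVERRIDES"


def overrides_from_env(env: Optional[Mapping[str, str]] = None) -> List[str]:
    """Read whitespace-separated Hydra overrides from an environment variable."""
    if env is None:
        env = os.environ
    # Overrides are split on whitespace, so individual values must not contain spaces
    return env.get(OVERRIDES_ENV_VAR, "").split()
```

In `train.py`, you could then set `sys.argv = [sys.argv[0]] + overrides_from_env()` just before calling `train()`. Because `@hydra.main` parses `sys.argv` when the decorated function is called, it sees the overrides as if they had been typed on the command line. This discards the `--key value` arguments SageMaker adds, so every setting would have to go through the environment variable.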
Upvotes: 0