Reputation: 52
I have some code that trains an RL agent in JAX, and it runs fine. To tune the hyperparameters I would like to use the Optuna sweeper plugin for Hydra, since my project is already based on Hydra. To this end, I created the following config file:
defaults:
  - base_conf
  - env: mjx_reacher
  - override hydra/launcher: joblib
  - override hydra/sweeper: optuna
  - _self_

hydra:
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
    direction: maximize
    n_jobs: 1
    n_trials: 50
    params:
      unroll_length: range(1, 50)
      num_minibatches: choice(8, 16, 32, 64)
      num_updates_per_batch: choice(2, 4, 8, 16)
      batch_size: choice(16, 64, 256)
      learning_rate: interval(1e-4, 5e-3)
  launcher:
    backend: processes
    prefer: processes
    n_jobs: ${hydra.sweeper.n_jobs}

# Data gen / alg params
grad_clip_norm: 1.
num_timesteps: 500_000
seed: 123
num_evals: 10
unroll_length: 5
num_minibatches: 32
num_updates_per_batch: 8
num_eval_envs: 32
batch_size: 16
num_eval_iters: 5

# Setup
port: 12334
devices: [0, 1, 3, 4]
save_intermediate_ckpts: false
load_only: false
debug: false

# Policy
normalize_observations: true
reward_scaling: 1.
discounting: .97
learning_rate: 3e-4
entropy_cost: 2e-2
networks:
  policy:
    hidden: [32, 32]
  value:
    hidden: [32, 32]
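For reference, this is how I launch the sweep (the entry-point name train.py is just a placeholder for my actual script):

# Run the Optuna sweep defined above; --multirun activates the sweeper and the joblib launcher
python train.py --multirun

# Parallelism can also be raised from the command line; the launcher's n_jobs
# follows automatically through the ${hydra.sweeper.n_jobs} interpolation
python train.py --multirun hydra.sweeper.n_jobs=4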
The script that I use looks as follows:
import gc
import os
from pathlib import Path

import hydra
from dotenv import load_dotenv
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig

load_dotenv()
REPO_PATH = Path(os.environ["REPO_PATH"])
CONFIG_PATH = str(REPO_PATH / "config")
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


@hydra.main(config_path=CONFIG_PATH, config_name="config", version_base=None)
def train_cli(cfg: DictConfig) -> float:
    # Map each Hydra job onto one of the configured GPUs (round-robin).
    run_idstr = HydraConfig.get().job.get("id").split("_")[-1]
    job_id = int(run_idstr)
    devices = cfg.devices
    os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[job_id % len(devices)])

    # _get_run_name and _train_ppo are helpers defined elsewhere in my project.
    run_name = _get_run_name(cfg.log_dir, cfg.get("data_path")) if cfg.debug is False else "debug"
    print(f"Starting run '{run_name}'; available devices: {os.environ['CUDA_VISIBLE_DEVICES']}")

    # Train
    make_inference_fn, params, metrics = _train_ppo(cfg, run_name)
    inference_fn = make_inference_fn(params, deterministic=True)

    # Cleanup
    del inference_fn
    gc.collect()

    # The returned value is the objective that the Optuna sweeper maximizes.
    return float(metrics["eval/episode_reward"])


if __name__ == "__main__":
    train_cli()
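To make the intended GPU assignment explicit, here is a standalone sanity check of the job-to-device mapping used in train_cli (illustrative only, not part of the training code):

# Mirrors the round-robin in train_cli: devices[job_id % len(devices)]
devices = [0, 1, 3, 4]  # same as cfg.devices
for job_id in range(6):
    print(f"job {job_id} -> CUDA_VISIBLE_DEVICES={devices[job_id % len(devices)]}")
# job 0 -> CUDA_VISIBLE_DEVICES=0
# job 1 -> CUDA_VISIBLE_DEVICES=1
# job 2 -> CUDA_VISIBLE_DEVICES=3
# job 3 -> CUDA_VISIBLE_DEVICES=4
# job 4 -> CUDA_VISIBLE_DEVICES=0  (wraps around)
# job 5 -> CUDA_VISIBLE_DEVICES=1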
This runs fine with a single job. I would now like to run multiple jobs in parallel to speed up the search, mapping each job onto its own GPU. When I do this, the script runs fine until it hangs after the first or second batch of trials; I suspect it gets stuck during JIT compilation. How can I prevent this from happening?
In other words: how can I use Optuna and joblib to tune the hyperparameters in my JAX script?
Upvotes: 0
Views: 38