HansDoe

Reputation: 52

Slow compilation / deadlocks when combining joblib and JAX

I have some code to train an RL agent in JAX. On its own, the code runs fine.

To tune the hyperparameters, I would like to use the Optuna sweeper plugin for Hydra, since my project is already built on Hydra. To this end, I created the following config file:

defaults:
  - base_conf
  - env: mjx_reacher
  - override hydra/launcher: joblib
  - override hydra/sweeper: optuna
  - _self_

hydra:
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
    direction: maximize
    n_jobs: 1
    n_trials: 50
    params:
      unroll_length: range(1, 50)
      num_minibatches: choice(8, 16, 32, 64)
      num_updates_per_batch: choice(2, 4, 8, 16)
      batch_size: choice(16, 64, 256)
      learning_rate: interval(1e-4, 5e-3)
  launcher:
    backend: processes
    prefer: processes
    n_jobs: ${hydra.sweeper.n_jobs}

# Data gen / alg params
grad_clip_norm: 1.
num_timesteps: 500_000
seed: 123
num_evals: 10
unroll_length: 5
num_minibatches: 32
num_updates_per_batch: 8
num_eval_envs: 32
batch_size: 16
num_eval_iters: 5

# Setup
port: 12334
devices: [0, 1, 3, 4]
save_intermediate_ckpts: false
load_only: false
debug: false

# Policy
normalize_observations: true
reward_scaling: 1.
discounting: .97
learning_rate: 3e-4
entropy_cost: 2e-2
networks:
  policy:
    hidden: [32, 32]
  value:
    hidden: [32, 32]
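
If I read the sweeper config correctly, the params block above translates into Optuna suggestions roughly like this (a sketch of my understanding, not the plugin's actual code; the function name is made up):

import optuna


def suggest_params(trial: optuna.Trial) -> dict:
    # Rough Optuna equivalents of the sweeper search space above.
    return {
        "unroll_length": trial.suggest_int("unroll_length", 1, 50),
        "num_minibatches": trial.suggest_categorical("num_minibatches", [8, 16, 32, 64]),
        "num_updates_per_batch": trial.suggest_categorical("num_updates_per_batch", [2, 4, 8, 16]),
        "batch_size": trial.suggest_categorical("batch_size", [16, 64, 256]),
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 5e-3),
    }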

The script that I use looks as follows:

import gc
import os
from pathlib import Path

import hydra
from dotenv import load_dotenv
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig

load_dotenv()
REPO_PATH = Path(os.environ["REPO_PATH"])
CONFIG_PATH = str(REPO_PATH / "config")
# Allocate GPU memory on demand instead of preallocating most of the device.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


@hydra.main(config_path=CONFIG_PATH, config_name="config", version_base=None)
def train_cli(cfg: DictConfig) -> None:
    # In a multirun, the hydra job id ends in the job number; use it to
    # round-robin this trial onto one of the configured GPUs.
    run_idstr = HydraConfig.get().job.get("id").split("_")[-1]
    job_id = int(run_idstr)
    devices = cfg.devices
    # This must happen before JAX initializes its backend in this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(devices[job_id % len(devices)])
    run_name = _get_run_name(cfg.log_dir, cfg.get("data_path")) if not cfg.debug else "debug"
    print(f"Starting run '{run_name}'; available devices: {os.environ['CUDA_VISIBLE_DEVICES']}")

    # Train; _train_ppo and _get_run_name are my own helpers (omitted here).
    make_inference_fn, params, metrics = _train_ppo(cfg, run_name)
    inference_fn = make_inference_fn(params, deterministic=True)

    # Cleanup: drop references so the worker frees GPU memory between trials.
    del inference_fn
    gc.collect()

    # The optuna sweeper maximizes this return value.
    return float(metrics["eval/episode_reward"])


if __name__ == "__main__":
    train_cli()
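
For clarity, the CUDA_VISIBLE_DEVICES assignment above just round-robins the hydra job number over cfg.devices (standalone toy example):

devices = [0, 1, 3, 4]  # cfg.devices from the config above
for job_id in range(6):
    print(f"job {job_id} -> GPU {devices[job_id % len(devices)]}")
# job 0 -> GPU 0, job 1 -> GPU 1, job 2 -> GPU 3,
# job 3 -> GPU 4, job 4 -> GPU 0, job 5 -> GPU 1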

If I use a single job, this runs fine. I would now like to use multiple jobs to speed up my search and map the trials across multiple GPUs. When I do this, the script runs fine for the first one or two batches of trials and then gets stuck, I suspect during compilation. How can I prevent this from happening?

In other words: how can I use optuna together with joblib to tune the hyperparameters of my JAX script?
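
For reference, here is a stripped-down sketch of the pattern I believe my setup reduces to, with joblib workers each pinning a GPU and then triggering XLA compilation (illustrative only; the function is made up and is not my actual training code):

import os

from joblib import Parallel, delayed

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"


def run_trial(job_id: int) -> float:
    # Pin one GPU per worker *before* JAX initializes its backend in this
    # process; if the backend was already initialized (e.g. in the parent
    # before a fork), this has no effect, which is one of my suspects.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(job_id % 4)
    import jax.numpy as jnp  # deferred so the backend starts per worker

    # The first traced op triggers XLA compilation, which is roughly where
    # my real runs appear to get stuck.
    return float(jnp.sum(jnp.arange(100_000, dtype=jnp.float32)) * job_id)


if __name__ == "__main__":
    scores = Parallel(n_jobs=4)(delayed(run_trial)(i) for i in range(8))
    print(scores)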

Upvotes: 0

Views: 38

Answers (0)
