CarterB

Reputation: 542

Extract agent from ray.tune

I have been using azure machine learning to train a reinforcement learning agent using ray.tune.

My training function is as follows:

    tune.run(
        run_or_experiment="PPO",
        config={
            "env": "Battery",
            "num_gpus": 1,
            "num_workers": 13,
            "num_cpus_per_worker": 1,
            "train_batch_size": 1024,
            "num_sgd_iter": 20,
            "explore": True,
            "exploration_config": {"type": "StochasticSampling"},
        },
        stop={"episode_reward_mean": 0.15},
        checkpoint_freq=200,
        local_dir="second_checkpoints",
    )

How can I extract the agent from a checkpoint so I can visualise the actions on my gym environment as follows:

    while not done:
        action, state, logits = agent.compute_action(obs, state)
        obs, reward, done, info = env.step(action)
        episode_reward += reward
        print('action: ' + str(action) + ' reward: ' + str(reward))


I understand I can use something like this:

    analysis = tune.run("PPO", config={"max_iter": 10}, restore=last_ckpt)

But I am unsure how to pull the computed actions (and rewards) from the agent that exists within tune.run.

Upvotes: 0

Views: 506

Answers (1)

Michael Möbius

Reputation: 1050

tune.run is used to train the model. After training you should have some checkpoint files in your local_dir. These can be loaded into a trainer and then played in your env:

    import ray.rllib.agents.ppo as ppo  # module that provides PPOTrainer

    # config/env_name should match what you passed to tune.run
    agent = ppo.PPOTrainer(config=config, env=env_name)
    agent.restore(checkpoint_file)  # checkpoint_file: path to a checkpoint in local_dir

    obs = env.reset()
    action = agent.compute_action(obs)
    obs, reward, done, info = env.step(action)
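To show the shape of the full evaluation loop without needing RLlib installed, here is a minimal sketch. `DummyEnv` and `DummyAgent` are hypothetical stand-ins (not part of Ray or gym) that follow the same `reset`/`step`/`compute_action` contract, so the loop body is exactly what you would run with the restored agent and your Battery env:

```python
class DummyEnv:
    """Stand-in gym-style env: episode ends after 3 steps, reward 1.0 per step."""
    def __init__(self):
        self.steps = 0
    def reset(self):
        self.steps = 0
        return 0.0  # initial observation
    def step(self, action):
        self.steps += 1
        done = self.steps >= 3
        return float(self.steps), 1.0, done, {}  # obs, reward, done, info

class DummyAgent:
    """Stand-in for a restored RLlib trainer: always returns action 0."""
    def compute_action(self, obs):
        return 0

env = DummyEnv()
agent = DummyAgent()

obs = env.reset()
done = False
episode_reward = 0.0
while not done:
    action = agent.compute_action(obs)
    obs, reward, done, info = env.step(action)
    episode_reward += reward

print(episode_reward)  # 3.0
```

Swap the stand-ins for `agent.restore(...)` and your registered env and the loop is unchanged.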

Upvotes: 1

Related Questions