Reputation: 542
I have been using Azure Machine Learning to train a reinforcement learning agent with ray.tune.
My training call is as follows:
tune.run(
    run_or_experiment="PPO",
    config={
        "env": "Battery",
        "num_gpus": 1,
        "num_workers": 13,
        "num_cpus_per_worker": 1,
        "train_batch_size": 1024,
        "num_sgd_iter": 20,
        "explore": True,
        "exploration_config": {"type": "StochasticSampling"},
    },
    stop={"episode_reward_mean": 0.15},
    checkpoint_freq=200,
    local_dir="second_checkpoints",
)
How can I extract the agent from a checkpoint so that I can visualise its actions in my gym environment, like this:
while not done:
    action, state, logits = agent.compute_action(obs, state)
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    print('action: ' + str(action) + ' reward: ' + str(reward))
I understand I can use something like this:
analysis = tune.run("PPO", config={"max_iter": 10}, restore=last_ckpt)
But I am unsure how to compute actions (and collect rewards) with the agent that lives inside tune.run.
Upvotes: 0
Views: 506
Reputation: 1050
tune.run is used to train the model. After training you should have checkpoint files on disk. These can be loaded into a fresh trainer, and the restored agent can then be played in your env.
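With the tune.run call from your question, the checkpoint file should sit under local_dir at a path along these lines (the experiment and trial directory names are auto-generated, so check what is actually on disk; the path below is only a placeholder):

# hypothetical path; Ray Tune writes checkpoints as checkpoint_<N>/checkpoint-<N>
checkpoint_file = 'second_checkpoints/PPO/PPO_Battery_<trial_id>/checkpoint_200/checkpoint-200'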
import ray
from ray.rllib.agents import ppo

ray.init()
# rebuild the trainer with the same config and env used for training, then load the weights
agent = ppo.PPOTrainer(config=config, env=env_name)
agent.restore(checkpoint_file)
obs = env.reset()
action = agent.compute_action(obs)
obs, reward, done, info = env.step(action)
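From there you can reproduce the visualisation loop from your question and step the env until the episode ends. A minimal sketch, assuming env is an instance of the same "Battery" environment the agent was trained on (the state/logits form of compute_action is only needed for recurrent models, so the plain call is enough for the default PPO network):

# roll out one full episode with the restored policy
obs = env.reset()
done = False
episode_reward = 0
while not done:
    action = agent.compute_action(obs)
    obs, reward, done, info = env.step(action)
    episode_reward += reward
    print('action: ' + str(action) + ' reward: ' + str(reward))
print('episode reward: ' + str(episode_reward))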
Upvotes: 1