Visgean Skeloru

Reputation: 2263

Stable Baselines: saving a PPO model and training it again

Hello, I am using the Stable Baselines package (https://stable-baselines.readthedocs.io/), specifically PPO2, and I am not sure how to save my model properly. I trained it for 6 virtual days and got the average return to around 300. Then I decided that this was not enough, so I trained the model for another 6 days. But when I looked at the training statistics, the return per episode in the second training run started at around 30. This suggests that not all parameters were saved.

This is how I use the package:

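import os

import gym
import pybullet as blt  # assumption: blt refers to the pybullet module (blt.DIRECT = headless mode)
from stable_baselines import PPO2
from stable_baselines.common import set_global_seeds
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import SubprocVecEnv, VecNormalize

# env_name, processes, folder (a pathlib.Path), ppo_kwards, total_timesteps
# and callback are defined elsewhere in my script.


def make_env_init(env_id, rank, seed=0):
    """
    Utility function for multiprocessed env.

    :param env_id: (str) the environment ID
    :param seed: (int) the initial seed for RNG
    :param rank: (int) index of the subprocess
    """

    def env_init():
        # Important: use a different seed for each environment
        env = gym.make(env_id, connection=blt.DIRECT)
        env.seed(seed + rank)
        return env

    set_global_seeds(seed)
    return env_init


envs = VecNormalize(SubprocVecEnv([make_env_init(f'envs:{env_name}', i) for i in range(processes)]), norm_reward=False)

# Resume from the previous run if a saved model exists, otherwise start fresh
if os.path.exists(folder / 'model_dump.zip'):
    model = PPO2.load(folder / 'model_dump.zip', envs, **ppo_kwards)
else:
    model = PPO2(MlpPolicy, envs, **ppo_kwards)

model.learn(total_timesteps=total_timesteps, callback=callback)
model.save(folder / 'model_dump.zip')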

Upvotes: 2

Views: 6767

Answers (1)

tonyTenerife

Reputation: 39

The way you saved the model is correct. Training is not a monotonic process: the model can also show much worse results after further training.
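The model file itself is saved correctly, but one more piece of state is worth checking in the question's code: `envs` is wrapped in `VecNormalize`, which normalizes observations with a running mean and variance kept in the wrapper, not in the model file. After a restart, a freshly created `VecNormalize` starts from default statistics, so the restored policy sees differently scaled inputs, which can make it perform much worse at first. The sketch below is a simplified, pure-Python stand-in for those running statistics (the class name `RunningStats` is hypothetical, not the library's code); it mirrors the wrapper's transform of centring by the running mean, dividing by the running standard deviation and clipping to ±10:

```python
import math


class RunningStats:
    """Hypothetical, simplified stand-in for VecNormalize's running
    observation statistics (scalar observations only)."""

    def __init__(self, epsilon=1e-4):
        self.mean = 0.0
        self.var = 1.0
        self.count = epsilon  # tiny pseudo-count, like a freshly created wrapper

    def update(self, batch):
        # Merge the batch mean/variance into the running estimate
        # using the parallel (Chan et al.) variance update.
        n = len(batch)
        b_mean = sum(batch) / n
        b_var = sum((x - b_mean) ** 2 for x in batch) / n
        delta = b_mean - self.mean
        total = self.count + n
        m2 = self.var * self.count + b_var * n + delta ** 2 * self.count * n / total
        self.mean += delta * n / total
        self.var = m2 / total
        self.count = total

    def normalize(self, x, clip=10.0, eps=1e-8):
        # Centre, scale and clip, the same shape of transform as VecNormalize
        z = (x - self.mean) / math.sqrt(self.var + eps)
        return max(-clip, min(clip, z))


# "Trained" statistics: toy observations centred around 50 (not real data)
trained = RunningStats()
for step in range(1000):
    trained.update([50.0 + (step % 7) - 3.0, 50.0 - (step % 5) + 2.0])

# Fresh statistics, as after re-creating VecNormalize on restart
fresh = RunningStats()

obs = 52.0
print(trained.normalize(obs))  # a value of order 1
print(fresh.normalize(obs))    # 10.0 (clipped)
```

With the real wrapper, the stable-baselines 2 documentation (2.10 and later) saves these statistics with `envs.save(path)` next to `model.save(...)`, and on restart wraps the new `SubprocVecEnv` with `VecNormalize.load(path, venv)` instead of constructing a new `VecNormalize(...)`.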

First of all, you can log the training progress for TensorBoard:

model = PPO2(MlpPolicy, envs, tensorboard_log="./logs/progress_tensorboard/")

To view the logs, run this in a terminal:

tensorboard --port 6004 --logdir ./logs/progress_tensorboard/

It prints a link to the dashboard, which you can then open in a browser (e.g. http://pc0259:6004/).
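Episode returns in these logs are noisy, so a single bad stretch after resuming is easier to judge on a smoothed curve. TensorBoard's smoothing slider is based on an exponential moving average; a plain (non-debiased) version of that idea, with a hypothetical helper name `ema` and an arbitrary weight of 0.6, looks like this:

```python
def ema(values, weight=0.6):
    """Exponential moving average: each point blends the previous
    smoothed value (weight) with the new raw value (1 - weight)."""
    smoothed = []
    last = None
    for v in values:
        last = v if last is None else weight * last + (1 - weight) * v
        smoothed.append(last)
    return smoothed


# Toy episode returns (not real data): noisy but trending upward
returns = [10, 40, 15, 60, 35, 80, 55, 95]
print([round(x, 1) for x in ema(returns)])
```

A higher weight gives a smoother but more delayed curve, so the apparent "best" point shifts slightly to the right compared with the raw data.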

Second, you can save snapshots of the model every X steps:

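from stable_baselines.common.callbacks import CheckpointCallback

# Save a snapshot every 10000 callback calls (save_freq is meant to be an int)
checkpoint_callback = CheckpointCallback(save_freq=10000, save_path='./model_checkpoints/')
# A list is wrapped into a CallbackList, so every entry must be a BaseCallback
# instance; a plain-function `callback` cannot be mixed into the list as-is
model.learn(total_timesteps=total_timesteps, callback=[callback, checkpoint_callback])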
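One detail matters with `SubprocVecEnv`: in stable-baselines 2, PPO2 calls the callback once per vectorized step, and each such step advances the timestep counter by the number of environments. `save_freq` therefore counts vectorized steps, so with `processes` workers a checkpoint lands every `save_freq * processes` timesteps, and the file is named after the timestep count (`rl_model_<timesteps>_steps.zip` with the default `name_prefix`). A small sketch of that arithmetic, with a hypothetical helper name and ignoring PPO2's rounding of `total_timesteps` down to whole rollouts:

```python
def checkpoint_timesteps(total_timesteps, n_envs, save_freq):
    """Timesteps at which CheckpointCallback-style saving would fire,
    assuming one callback call per vectorized step (n_envs timesteps)."""
    n_calls = total_timesteps // n_envs  # number of vectorized steps taken
    return [k * n_envs for k in range(save_freq, n_calls + 1, save_freq)]


# Example: 8 worker processes, save_freq=10000, 250k total timesteps
print(checkpoint_timesteps(250_000, n_envs=8, save_freq=10_000))
# [80000, 160000, 240000]
```

If you want a snapshot roughly every N timesteps regardless of the number of workers, dividing N by the number of processes when choosing `save_freq` compensates for this.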

Combined with the logs, this lets you pick the checkpoint that performed best.
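Choosing a checkpoint comes down to pairing each file with a score for that point in training, for instance the smoothed episode return read from TensorBoard near that timestep, or a separate evaluation run. A minimal sketch, assuming the default `rl_model_<timesteps>_steps.zip` naming and a hypothetical `scores` mapping from timestep to evaluation return that you fill in yourself:

```python
import re
from pathlib import Path

# Matches the default CheckpointCallback naming, e.g. rl_model_80000_steps.zip
CHECKPOINT_RE = re.compile(r"rl_model_(\d+)_steps\.zip$")


def best_checkpoint(paths, scores):
    """Return the checkpoint path whose timestep has the highest score.

    paths:  checkpoint file paths (str or Path)
    scores: hypothetical {timestep: mean evaluation return} mapping
    """
    best_path, best_score = None, float("-inf")
    for p in paths:
        match = CHECKPOINT_RE.search(Path(p).name)
        if match is None:
            continue  # not a checkpoint file
        step = int(match.group(1))
        if step in scores and scores[step] > best_score:
            best_path, best_score = Path(p), scores[step]
    return best_path


# Toy usage (made-up scores, for illustration only)
files = ["model_checkpoints/rl_model_80000_steps.zip",
         "model_checkpoints/rl_model_160000_steps.zip",
         "model_checkpoints/rl_model_240000_steps.zip"]
print(best_checkpoint(files, {80000: 120.0, 160000: 310.0, 240000: 290.0}))
```

In practice `paths` could come from `Path('./model_checkpoints/').glob('rl_model_*_steps.zip')`, and the selected file can then be passed to `PPO2.load(...)` as in the question's code.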

Upvotes: 3
