Vladimir Belik

Reputation: 400

Stable Baselines3 PPO() - how to change clip_range parameter during training?

I want to gradually decrease the clip_range (epsilon, the PPO clipping parameter) throughout training in my PPO model.

I have tried to simply run "model.clip_range = new_value", but this doesn't work.

In the docs here, it says "clip_range (Union[float, Callable[[float], float]]) – Clipping parameter, it can be a function of the current progress remaining (from 1 to 0)."

Does anyone know how to actually change this parameter during training, or how to input "a function of the current progress remaining"?

Upvotes: 4

Views: 2293

Answers (2)

Tetramputechture

Reputation: 2921

Stable Baselines 3 has a utils module that exports the function you want: get_linear_fn (docs here).

Usage:

from stable_baselines3 import PPO
from stable_baselines3.common.utils import get_linear_fn

# Linearly anneal from 3e-4 to 5e-5 over the first 85% of training
learning_rate = get_linear_fn(
    start=3e-4,
    end=5e-5,
    end_fraction=0.85
)

model = PPO(learning_rate=learning_rate, ...)
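
The same kind of schedule can be passed for clip_range, which is what the question asks about, since PPO accepts a callable there as well. A minimal sketch, assuming env is an existing Gym environment and the start/end values are just illustrative:

clip_range = get_linear_fn(
    start=0.2,    # initial clipping epsilon
    end=0.05,     # final clipping epsilon
    end_fraction=0.85
)

model = PPO("MlpPolicy", env, clip_range=clip_range)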

Upvotes: 0

Vladimir Belik

Reputation: 400

I've solved the issue.

You need to have a slightly funky setup where a function outputs another function. At this link, they give the following example:

def linear_schedule(initial_value):
    """
    Linear learning rate schedule.
    :param initial_value: (float or str)
    :return: (function)
    """
    if isinstance(initial_value, str):
        initial_value = float(initial_value)

    def func(progress):
        """
        Progress will decrease from 1 (beginning) to 0
        :param progress: (float)
        :return: (float)
        """
        return progress * initial_value

    return func

So essentially, what you have to do is write a function, e.g. myscheduler(), which doesn't necessarily need any inputs, and whose output is another function that takes "progress" (which goes from 1 down to 0 as training proceeds) as its only input. That "progress" value is passed to the function by PPO itself. So, I suppose the "under the hood" order of events is something like this (a quick sketch follows the list):

  1. Your learning_rate scheduling function is called.
  2. It returns a function that takes progress as its input.
  3. SB3's PPO (or another algorithm) passes its current progress into that function.
  4. That function outputs the learning_rate to use, and the model grabs it and runs with it.
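
As a rough illustration of that flow (the variable names here are just for the example, not SB3 internals), using the linear_schedule above:

schedule = linear_schedule(0.003)            # steps 1-2: build the schedule, get back a function
progress_remaining = 0.5                     # step 3: SB3 tracks this internally, going from 1 to 0
current_lr = schedule(progress_remaining)    # step 4: 0.5 * 0.003 = 0.0015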

In my case, I wrote something like this:

def lrsched():
    def reallr(progress):
        # progress goes from 1 (start of training) down to 0 (end)
        lr = 0.003
        if progress < 0.85:
            lr = 0.0005
        if progress < 0.66:
            lr = 0.00025
        if progress < 0.33:
            lr = 0.0001
        return lr
    return reallr

Then, you use that function in the following way:

model = PPO(..., learning_rate=lrsched())
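
And since my original question was about clip_range, the same pattern should work there too, because clip_range also accepts a function of the progress remaining. Something along these lines, with the names (clipsched, realclip) and values purely as an example:

def clipsched():
    def realclip(progress):
        # progress goes from 1 (start of training) down to 0 (end)
        clip = 0.3
        if progress < 0.85:
            clip = 0.2
        if progress < 0.33:
            clip = 0.1
        return clip
    return realclip

model = PPO(..., clip_range=clipsched())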

Upvotes: 7
