rohit deraj

Reputation: 95

Can't load saved policy (TF-Agents)

I saved a trained policy with the policy saver as follows:

  tf_policy_saver = policy_saver.PolicySaver(agent.policy)
  tf_policy_saver.save(policy_dir)

I want to continue training with the saved policy, so I tried initializing the training with it, which caused the error below.

  agent = dqn_agent.DqnAgent(
      tf_env.time_step_spec(),
      tf_env.action_spec(),
      q_network=q_net,
      optimizer=optimizer,
      td_errors_loss_fn=common.element_wise_squared_loss,
      train_step_counter=train_step_counter)

  agent.initialize()

  agent.policy = tf.compat.v2.saved_model.load(policy_dir)

ERROR:

  File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py", line 172, in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\\Two_rewards')


File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py", line 92, in __setattr__
    super(AutoTrackable, self).__setattr__(name, value)
AttributeError: can't set attribute

I just want to avoid retraining from scratch every time. How can I load the saved policy and continue training?

Thanks in advance

Upvotes: 3

Views: 1431

Answers (2)

Federico Malerba

Reputation: 815

Yes, as previously stated, you should use the Checkpointer to do this; have a look at the example code below.

agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)

... # Train the agent

# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())

When you later want to reload the policy, you simply run the same initialization code.

agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1; possibly Y1 == Y depending on the agent class you are using.
#               For DQN they differ because of the random initialization of the network weights.
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
# Policy --> X

Upon creation, the policy_checkpointer will automatically detect whether any preexisting checkpoints are in ckpt_dir. If there are, it restores the values of the variables it is tracking as soon as it is constructed.
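As a minimal sketch of that behaviour (to the best of my knowledge, common.Checkpointer exposes a checkpoint_exists property for exactly this check; treat it as an assumption and verify against your TF-Agents version):

if policy_checkpointer.checkpoint_exists:
    # The constructor already restored every variable the checkpointer tracks.
    print('Policy restored from the latest checkpoint in path/to/dir.')
else:
    print('No checkpoint found; the policy keeps its fresh initialization.')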

A couple of notes:

  1. You can save a lot more than just the policy with the checkpointer, and indeed I recommend doing so. TF-Agents' Checkpointer object is extremely flexible, e.g.:
train_checkpointer = common.Checkpointer(ckpt_dir='first/dir',
                                         agent=tf_agent,               # tf_agent.TFAgent
                                         train_step=train_step,        # tf.Variable
                                         epoch_counter=epoch_counter,  # tf.Variable
                                         metrics=metric_utils.MetricsGroup(
                                                 train_metrics, 'train_metrics'))

policy_checkpointer = common.Checkpointer(ckpt_dir='second/dir',
                                          policy=agent.policy)

rb_checkpointer = common.Checkpointer(ckpt_dir='third/dir',
                                      max_to_keep=1,
                                      replay_buffer=replay_buffer)  # TFUniformReplayBuffer
  2. Note that in the case of a DqnAgent, the agent.policy and agent.collect_policy are essentially wrappers around a QNetwork. The implication of this is shown in the code below (look at the comments on the state of the policy variable):
agent = DqnAgent(...)
policy = agent.policy      # Random initial policy ---> X

dataset = replay_buffer.as_dataset(...)
for data in dataset:
   experience, _ = data
   loss_agent_info = agent.train(experience=experience)

# policy variable stores a trained Policy object ---> Y

This happens because Tensors in TF are shared across your runtime. Therefore, when you update your agent's QNetwork weights with agent.train, those same weights are implicitly updated in your policy variable's QNetwork as well. Indeed, it's not that the policy's Tensors get updated, but rather that they simply are the same Tensors as the ones in your agent.
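A minimal sketch that makes this sharing visible (assuming the agent and q_net from the question; the id() comparison is just an illustrative check, not part of the TF-Agents API):

# The policy does not hold copies of the weights: its variables are the
# very same tf.Variable objects owned by the agent's QNetwork.
policy_var_ids = {id(v) for v in agent.policy.variables()}
for v in q_net.trainable_variables:
    assert id(v) in policy_var_ids  # same objects, so agent.train updates both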

Upvotes: 5

m.a.a.

Reputation: 137

You should check out Checkpointer for that purpose.
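For instance, a minimal sketch (the directory path is a placeholder, and tracking the agent alongside its policy is just one reasonable setup):

checkpointer = common.Checkpointer(ckpt_dir='path/to/checkpoint',
                                   agent=agent,
                                   policy=agent.policy)
...  # train the agent
checkpointer.save(global_step=agent.train_step_counter.numpy())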

Upvotes: 1
