Reputation: 95
I saved the trained policy with the policy saver as follows:
tf_policy_saver = policy_saver.PolicySaver(agent.policy)
tf_policy_saver.save(policy_dir)
I want to continue training with the saved policy, so I tried initializing the training with the saved policy, which caused the following error.
agent = dqn_agent.DqnAgent(
    tf_env.time_step_spec(),
    tf_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()
agent.policy=tf.compat.v2.saved_model.load(policy_dir)
ERROR:
File "C:/Users/Rohit/PycharmProjects/pythonProject/waypoint.py", line 172, in <module>
agent.policy=tf.compat.v2.saved_model.load('waypoints\\Two_rewards')
File "C:\Users\Rohit\anaconda3\envs\btp36\lib\site-packages\tensorflow\python\training\tracking\tracking.py", line 92, in __setattr__
super(AutoTrackable, self).__setattr__(name, value)
AttributeError: can't set attribute
I just want to avoid retraining from scratch every time. How can I load the saved policy and continue training?
Thanks in advance
Upvotes: 3
Views: 1431
Reputation: 815
Yes, as previously stated, you should use the Checkpointer to do this; have a look at the example code below.
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
... # Train the agent
# Policy --> X
policy_checkpointer.save(global_step=epoch_counter.numpy())
When you later want to reload the policy, you simply run the same initialization code.
agent = ... # Agent Definition
policy = agent.policy
# Policy --> Y1, where Y1 may equal Y depending on the agent class you are using; if it's DQN,
# then they differ because of the random initialization of the network weights
policy_checkpointer = common.Checkpointer(ckpt_dir='path/to/dir',
                                          policy=policy)
# Policy --> X
Upon creation, the policy_checkpointer will automatically detect whether any preexisting checkpoints exist in ckpt_dir. If there are, it will restore the values of the variables it is tracking as soon as it is created.
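If you want to confirm that a restore actually happened, you can inspect the checkpointer after construction. A minimal sketch (checkpoint_exists and initialize_or_restore() are existing members of tf_agents.utils.common.Checkpointer; the prints are just illustrative):
policy_checkpointer.initialize_or_restore()  # needed in graph mode; in TF2 eager mode the
                                             # constructor already restored the latest checkpoint
if policy_checkpointer.checkpoint_exists:
    print('Restored policy variables from the latest checkpoint.')
else:
    print('No checkpoint found; the policy keeps its freshly initialized weights.')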
A couple of notes to make. First, a typical setup uses separate checkpointers for the training state, the policy, and the replay buffer:
train_checkpointer = common.Checkpointer(ckpt_dir='first/dir',
                                         agent=tf_agent,  # tf_agent.TFAgent
                                         train_step=train_step,  # tf.Variable
                                         epoch_counter=epoch_counter,  # tf.Variable
                                         metrics=metric_utils.MetricsGroup(
                                             train_metrics, 'train_metrics'))
policy_checkpointer = common.Checkpointer(ckpt_dir='second/dir',
                                          policy=agent.policy)
rb_checkpointer = common.Checkpointer(ckpt_dir='third/dir',
                                      max_to_keep=1,
                                      replay_buffer=replay_buffer)  # TFUniformReplayBuffer
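A short usage sketch with these three checkpointers: num_epochs and checkpoint_interval are hypothetical names for your own training settings, while the save(global_step=...) call is the actual Checkpointer API.
for epoch in range(num_epochs):
    ...  # collect experience and call agent.train as usual
    epoch_counter.assign_add(1)
    if epoch % checkpoint_interval == 0:
        step = train_step.numpy()
        train_checkpointer.save(global_step=step)
        policy_checkpointer.save(global_step=step)
        rb_checkpointer.save(global_step=step)
On the next run, simply re-creating these three Checkpointer objects restores the agent weights, the counters and metrics, and the replay buffer contents, so training resumes from where it stopped.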
Second, for the DqnAgent, the agent.policy and agent.collect_policy are essentially wrappers around a QNetwork. The implication of this is shown in the code below (look at the comments on the state of the policy variable):
agent = DqnAgent(...)
policy = agent.policy  # Random initial policy ---> X

dataset = replay_buffer.as_dataset(...)

for data in dataset:
    experience, _ = data
    loss_agent_info = agent.train(experience=experience)

# policy variable stores a trained Policy object ---> Y
This happens because Tensors in TF are shared across your runtime. Therefore, when you update your agent's QNetwork weights with agent.train, those same weights are implicitly updated in your policy variable's QNetwork as well. Indeed, it is not that the policy's Tensors get updated; rather, they simply are the same Tensors as the ones in your agent.
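You can observe this directly. A minimal sketch, assuming agent is the DqnAgent from the snippet above and experience is one batch taken from the replay buffer as in the loop above:
w_before = agent.policy.variables()[0].numpy()  # snapshot of one Q-network weight (a copy)
agent.train(experience=experience)              # one gradient step on the agent
w_after = agent.policy.variables()[0].numpy()
print((w_before != w_after).any())  # expected True: the policy sees the new weights
                                    # without any explicit copy or reload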
Upvotes: 5