I have been trying to build an RL agent using TF-Agents in TensorFlow. I first hit the issue in a custom-built environment, but I reproduced it using an official TF Colab example. The problem occurs whenever I use QRnnNetwork as the network for the DqnAgent. The agent works fine with a regular QNetwork, but with QRnnNetwork the policy_state_spec takes on a different structure. How can I fix this?
This is the structure the policy_state_spec gets converted to, whereas the original shape is ():
ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'), TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])
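import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_rnn_network
from tf_agents.utils import common

# Setup assumed from the DQN Colab (CartPole-v0, learning rate 1e-3)
learning_rate = 1e-3
train_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))

q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    lstm_size=(16,),
)
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)
agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()
collect_policy = agent.collect_policy
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()
collect_policy.action(time_step)  # raises the TypeError below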
I get this error:
TypeError: policy_state and policy_state_spec structures do not match:
()
vs.
ListWrapper([., .])
I went into the code, and it seems that for RNNs the action(time_step, policy_state, seed) method needs the state of the policy from the previous step, as the documentation says:
policy_state: A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state. https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/GreedyPolicy#action
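Your error:
TypeError: policy_state and policy_state_spec structures do not match:
()
vs.
ListWrapper([., .])
is telling you that you must provide the RNN's internal state to the action method. Because policy_state defaults to (), calling action(time_step) on its own passes an empty tuple, and that cannot match the two-element LSTM state spec. I found an example in the documentation: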
https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/TFPolicy#example_usage
The code it shows (as of August 8, 2021) is the following:
env = SomeTFEnvironment()
policy = TFRandomPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy
policy_state = policy.get_initial_state(env.batch_size)
time_step = env.reset()
while not time_step.is_last():
policy_step = policy.action(time_step, policy_state)
time_step = env.step(policy_step.action)
policy_state = policy_step.state
# policy_step.info may contain side info for logging, such as action log
# probabilities.
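If you structure your code the same way (get the initial state once with get_initial_state, pass it to action(), then replace it with policy_step.state after every step), the error should go away.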