devboydan

Issue implementing a Q-RNN in TF-Agents

I have been trying to build an RL agent using TF-Agents in TensorFlow. I first hit the issue in a custom environment, but I reproduced it with an official TF-Agents Colab example. The problem occurs whenever I use QRnnNetwork as the network for DqnAgent. The agent works fine with a regular QNetwork, but with QRnnNetwork the policy_state_spec changes shape. How would I remedy this?

With QRnnNetwork, policy_state_spec becomes the following, whereas originally (with the regular QNetwork) it is just ():

ListWrapper([TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_0'), TensorSpec(shape=(16,), dtype=tf.float32, name='network_state_1')])
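A likely explanation for the two (16,) specs: QRnnNetwork builds on an LSTM cell by default, and an LSTM carries two state vectors between time steps, the hidden state h and the cell state c, each with as many units as the layer. The mapping to network_state_0 and network_state_1 is an inference, not something the error message states. A quick check, assuming TensorFlow 2.x with tf.keras:

```python
import tensorflow as tf

# lstm_size=(16,) in QRnnNetwork means one LSTM layer with 16 units.
cell = tf.keras.layers.LSTMCell(16)

# An LSTM cell keeps two state vectors per example: hidden state h and
# cell state c, both of size 16. These presumably correspond to the two
# TensorSpec(shape=(16,)) entries in policy_state_spec.
print(cell.state_size)
```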

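import tensorflow as tf
from tf_agents.agents.dqn import dqn_agent
from tf_agents.environments import suite_gym, tf_py_environment
from tf_agents.networks import q_rnn_network
from tf_agents.utils import common

# Setup values assumed from the DQN Colab (CartPole-v0, learning rate 1e-3).
learning_rate = 1e-3
train_env = tf_py_environment.TFPyEnvironment(suite_gym.load('CartPole-v0'))

q_net = q_rnn_network.QRnnNetwork(
    train_env.observation_spec(),
    train_env.action_spec(),
    lstm_size=(16,),
    )
optimizer = tf.compat.v1.train.AdamOptimizer(learning_rate=learning_rate)
train_step_counter = tf.Variable(0)

agent = dqn_agent.DqnAgent(
    train_env.time_step_spec(),
    train_env.action_spec(),
    q_network=q_net,
    optimizer=optimizer,
    td_errors_loss_fn=common.element_wise_squared_loss,
    train_step_counter=train_step_counter)
agent.initialize()

collect_policy = agent.collect_policy
example_environment = tf_py_environment.TFPyEnvironment(
    suite_gym.load('CartPole-v0'))
time_step = example_environment.reset()

# Raises the TypeError below: no policy_state is passed.
collect_policy.action(time_step)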

I get this error:

TypeError: policy_state and policy_state_spec structures do not match:
  ()
vs.
  ListWrapper([., .])

Answers (1)

Carlos Hidalgo

I went into the code, and it seems that for RNNs the action(time_step, policy_state, seed) method needs the policy state from the previous step, as the documentation says:

policy_state: A Tensor, or a nested dict, list or tuple of Tensors representing the previous policy_state.
(Source: https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/GreedyPolicy#action)

What your error:

TypeError: policy_state and policy_state_spec structures do not match:
  ()
vs.
  ListWrapper([., .])

is trying to say is that you need to pass the RNN's internal state to the action method. I found an example in the documentation:

https://www.tensorflow.org/agents/api_docs/python/tf_agents/policies/TFPolicy#example_usage

The code it shows (as of August 8, 2021) is the following:

env = SomeTFEnvironment()
policy = TFRandomPolicy(env.time_step_spec(), env.action_spec())
# Or policy = agent.policy or agent.collect_policy

policy_state = policy.get_initial_state(env.batch_size)
time_step = env.reset()

while not time_step.is_last():
  policy_step = policy.action(time_step, policy_state)
  time_step = env.step(policy_step.action)

  policy_state = policy_step.state
  # policy_step.info may contain side info for logging, such as action log
  # probabilities.

If you implement your code in this fashion, it might work!
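To see why the state has to be threaded through the loop, here is a framework-free toy. ToyRecurrentPolicy and PolicyStep below are hypothetical stand-ins, not TF-Agents classes; the toy only mimics the structure check on policy_state and the fact that a recurrent policy's action depends on the state it carries:

```python
from typing import NamedTuple, Tuple


class PolicyStep(NamedTuple):
    """Hypothetical stand-in for a policy step: an action plus new state."""
    action: int
    state: Tuple[float, float]


class ToyRecurrentPolicy:
    """Hypothetical recurrent policy whose state is a pair (h, c),
    loosely mirroring the two-tensor LSTM state in the question."""

    def get_initial_state(self) -> Tuple[float, float]:
        return (0.0, 0.0)

    def action(self, observation: float, policy_state=()) -> PolicyStep:
        # Mimic the structure check: the state must be a pair.
        if not (isinstance(policy_state, tuple) and len(policy_state) == 2):
            raise TypeError(
                "policy_state and policy_state_spec structures do not match: "
                f"{policy_state!r} vs. (h, c)")
        _, c = policy_state
        c = 0.5 * c + observation   # toy memory: decayed sum of observations
        h = c                       # toy output state
        return PolicyStep(action=int(h > 0), state=(h, c))


policy = ToyRecurrentPolicy()

# Calling without a state fails, like collect_policy.action(time_step).
try:
    policy.action(1.0)
except TypeError as err:
    print("error:", err)

# Threading the state through the loop works, and the same observation
# (-0.2) can yield different actions depending on the carried state.
state = policy.get_initial_state()
actions = []
for obs in [1.0, -0.2, -0.2, -0.2]:
    step = policy.action(obs, state)
    actions.append(step.action)
    state = step.state
print(actions)
```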