I am looking for an experience replay buffer that is compatible with graph mode and tf.function. The simple code below, which uses TFUniformReplayBuffer from TF-Agents, fails with an InvalidArgumentError when I call add_batch.
import tensorflow as tf
import numpy as np
!pip install tf_agents
import tf_agents
from tf_agents.replay_buffers import tf_uniform_replay_buffer
data_spec = (
    tf.TensorSpec((5,), tf.float32, 'state'),
    tf.TensorSpec((), tf.float32, 'action')
)
print(data_spec)
buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=data_spec,
    batch_size=1,
    max_length=1000
)
state = tf.constant(tf.ones(5), dtype=tf.float32)
action = tf.constant(3.)
save = (state, action)
buffer.add_batch(save)
Error:
InvalidArgumentError Traceback (most recent call last)
in ()
----> 1 buffer.add_batch(save)
6 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/ops.py in raise_from_not_ok_status(e, name)
7105 def raise_from_not_ok_status(e, name):
7106 e.message += (" name: " + name if name is not None else "")
-> 7107 raise core._status_to_exception(e) from None # pylint: disable=protected-access
7108
7109
InvalidArgumentError: Must have updates.shape = indices.shape + params.shape[1:] or updates.shape = [], got updates.shape [5], indices.shape [1], params.shape [1000,5] [Op:ResourceScatterUpdate]