Reputation: 738
I am using the tf-agents library to build a contextual bandit.
For this I am building a custom environment.
I am creating a BanditPyEnvironment and wrapping it in a TFPyEnvironment.
The TFPyEnvironment automatically adds a batch-size dimension to the observation spec. I need to account for this batch dimension in the _observe and _apply_action methods: depending on the batch size, _observe should return that many observations, and _apply_action should take in that many actions and return that many rewards.
I am unable to find a single example of how to tell the TF environment what the batch size is, without it automatically adding a 1 as the first dimension. Can someone please clarify?
import numpy as np
import tensorflow as tf
from tf_agents.bandits.environments.bandit_py_environment import BanditPyEnvironment
from tf_agents.specs import BoundedTensorSpec

class SampleEnvironment(BanditPyEnvironment):
    # (_apply_action not shown)

    def __init__(self, batch_size):
        self.batch_size = batch_size
        observation_spec = BoundedTensorSpec(
            (2,), np.int32, minimum=[1, 1], maximum=[5, 2], name='observation')
        action_spec = BoundedTensorSpec(
            shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
        super(SampleEnvironment, self).__init__(observation_spec, action_spec)

    def _observe(self):
        # Build one observation per batch entry by hand.
        batch = []
        for _ in range(self.batch_size):
            each = tf.cast(
                np.array([np.random.choice([1, 2, 3, 4, 5]),
                          np.random.choice([1, 2])]), 'int32')
            batch.append(each)
        self.observation = np.array(batch)
        print("in observe", self.observation)
        return np.array(self.observation)
When I try to account for the batch size in the _observe method like above (using a for loop over the batch size), the TF environment again adds a 1 as the first (batch) dimension. Is there a way to tell the environment that the batch size is, say, 3, instead of it automatically adding 1? And how would I then account for this batch size in the replay buffer and agents?
Upvotes: 2
Views: 592
Reputation: 738
This can be done using the BatchedPyEnvironment class, as shown in the example below. It looks like the bandit environment above is a non-batched environment.
SampleEnvironment below is the BanditPyEnvironment shown in the question.
from tf_agents.environments import batched_py_environment, tf_py_environment

batch_size = 4
# One independent (non-batched) environment per batch entry.
py_envs = [SampleEnvironment() for _ in range(batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)
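A quick way to confirm the batching took effect (a sketch, assuming each SampleEnvironment instance emits a single (2,)-shaped observation as in the question's spec):

print(tfenv.batch_size)             # 4
time_step = tfenv.reset()
print(time_step.observation.shape)  # (4, 2): batch dimension first, then the (2,) observation

Downstream components can then read the batch size straight off the environment instead of hard-coding it, e.g. by passing tfenv.batch_size as the batch_size argument of a TFUniformReplayBuffer.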
Upvotes: 1