Boris W
Boris W

Reputation: 11

Received incompatible tensor at flattened index 4 from table 'uniform_table'

I'm trying to adapt the TensorFlow Agents tutorial to a custom environment. It's not very complicated and meant to teach me how this works. The game is basically a 21x21 grid with tokens the agent can collect for a reward by walking around. I can validate the environment, the agent, and the replay buffer, but when I try to train the model, i get an error message (see bottom). Any advice would be welcome !

The agent class is:

import numpy as np
import random
from IPython.display import clear_output
import time


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import tensorflow as tf
import numpy as np

from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts

class cGame (py_environment.PyEnvironment):
    def __init__(self):
        self.xdim = 21
        self.ydim = 21
        self.mmap = np.array([[0]*self.xdim]*self.ydim)
        self._turnNumber = 0
        self.playerPos = {"x":1, "y":1}
        self.totalScore = 0
        self.reward = 0.0
        self.input = 0
        self.addRewardEveryNTurns = 4
        self.addBombEveryNTurns = 3
        self._episode_ended = False
    
        ## player = 13
        ## bomb   = 14
        
        self._action_spec = array_spec.BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=3, name='action')
        self._observation_spec = array_spec.BoundedArraySpec(shape = (441,),  minimum=np.array([-1]*441), maximum = np.array([20]*441), dtype=np.int32, name='observation')  #(self.xdim, self.ydim)  , self.mmap.shape,  minimum = -1, maximum = 10

    def action_spec(self):
        return self._action_spec

    def observation_spec(self):
        return self._observation_spec

    def addMapReward(self):
        dx = random.randint(1, self.xdim-2)
        dy = random.randint(1, self.ydim-2)
        if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
            self.mmap[dy][dx] = random.randint(1, 9)
        return True
    
    def addBombToMap(self):
        dx = random.randint(1, self.xdim-2)
        dy = random.randint(1, self.ydim-2)
        if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
            self.mmap[dy][dx] = 14
        return True
        
    def _reset (self):
        self.mmap = np.array([[0]*self.xdim]*self.ydim)
        for y in range(self.ydim):
            self.mmap[y][0] = -1
            self.mmap[y][self.ydim-1] = -1
        for x in range(self.xdim):
            self.mmap[0][x] = -1
            self.mmap[self.ydim-1][x] = -1
            
        self.playerPos["x"] = random.randint(1, self.xdim-2)
        self.playerPos["y"] = random.randint(1, self.ydim-2)
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
        
        for z in range(10):
            ## place 10 targets
            self.addMapReward()
        for z in range(5):
            ## place 5 bombs
            ## bomb   = 14
            self.addBombToMap()
        self._turnNumber = 0
        self._episode_ended = False
        #return ts.restart (self.mmap)
        dap = ts.restart(np.array(self.mmap, dtype=np.int32).flatten())
        return (dap)
            
    def render(self, mapToRender):
        mapToRender.reshape(21,21)
        for y  in range(self.ydim):
            o =""
            for x in range(self.xdim):
                if mapToRender[y][x]==-1:
                    o=o+"#"
                elif mapToRender[y][x]>0 and mapToRender[y][x]<10:
                    o=o+str(mapToRender[y][x])
                elif mapToRender[y][x] == 13:
                    o=o+"@"
                elif mapToRender[y][x] == 14:
                    o=o+"*"
                else:
                    o=o+" "
            print (o)
        print ('TOTAL SCORE:', self.totalScore, 'LAST TURN SCORE:', self.reward)
        return True
    
    def getInput(self):
        self.input = 0
        i = input()
        if i == 'w' or i == '0':
            print ('going N')
            self.input = 1
        if i == 's' or i == '1':
            print ('going S')
            self.input = 2
        if i == 'a' or i == '2':
            print ('going W')
            self.input = 3
        if i == 'd' or i == '3':
            print ('going E')
            self.input = 4
        if i == 'x':
            self.input = 5
        return self.input
    
    def processMove(self):
        
        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
        self.reward = 0
        if self.input == 0:
            self.playerPos["y"] -=1
        if self.input == 1:
            self.playerPos["y"] +=1
        if self.input == 2:
            self.playerPos["x"] -=1
        if self.input == 3:
            self.playerPos["x"] +=1
        
        cloc = self.mmap[self.playerPos["y"]][self.playerPos["x"]]
        
        if  cloc == -1 or cloc ==14:
            self.totalScore = 0
            self.reward = -99
        
        if cloc >0 and cloc < 10:
            self.totalScore += cloc
            self.reward = cloc
            self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0

        self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13

        self.render(self.mmap)
    
    def runTurn(self):
        clear_output(wait=True)
        if self._turnNumber % self.addRewardEveryNTurns == 0:
            self.addMapReward()
        if self._turnNumber % self.addBombEveryNTurns == 0:
            self.addBombToMap()
        
        self.getInput()
        self.processMove()
        self._turnNumber +=1
        if self.reward == -99:
            self._turnNumber +=1
            self._reset()
            self.totalScore = 0
            self.render(self.mmap)
        return (self.reward)
    
    def _step (self, action):
        
        if self._episode_ended == True:
            return self._reset() 
        
        clear_output(wait=True)
        if self._turnNumber % self.addRewardEveryNTurns == 0:
            self.addMapReward()
        if self._turnNumber % self.addBombEveryNTurns == 0:
            self.addBombToMap()

        ## make sure action does produce exceed range
        #if action > 5 or action <1:
        #    action =0
        self.input = action  ## value 1 to 4
        self.processMove()
        self._turnNumber +=1
        
        if self.reward == -99:
            self._turnNumber +=1
            self._episode_ended = True
            #self._reset()
            self.totalScore = 0
            self.render(self.mmap)
            return ts.termination(np.array(self.mmap, dtype=np.int32).flatten(), reward = self.reward)
        else:
            return ts.transition(np.array(self.mmap, dtype=np.int32).flatten(), reward = self.reward) #, discount = 1.0
    
    def run (self):
        self._reset()
        self.render(self.mmap)
        while (True):
            self.runTurn()
            if self.input == 5:
                return ("EXIT on input x ")

env = cGame()

The class I want to use for training the model is:

from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
import reverb
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.drivers import py_driver
from tf_agents.environments import BatchedPyEnvironment


class mTrainer:
    def __init__ (self):
        
        self.train_env = tf_py_environment.TFPyEnvironment(cGame())
        self.eval_env  = tf_py_environment.TFPyEnvironment(cGame())
        
        self.num_iterations = 20000 # @param {type:"integer"}
        self.initial_collect_steps = 100  # @param {type:"integer"}
        self.collect_steps_per_iteration = 100 # @param {type:"integer"}
        self.replay_buffer_max_length = 100000  # @param {type:"integer"}
        self.batch_size = 64  # @param {type:"integer"}
        self.learning_rate = 1e-3  # @param {type:"number"}
        self.log_interval = 200  # @param {type:"integer"}
        self.num_eval_episodes = 10  # @param {type:"integer"}
        self.eval_interval = 1000  # @param {type:"integer"}

        
        
    def createAgent(self):
        fc_layer_params = (100, 50)
        action_tensor_spec = tensor_spec.from_spec(self.train_env.action_spec())
        num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

        def dense_layer(num_units):
            return tf.keras.layers.Dense(
                num_units,
                activation=tf.keras.activations.relu,
                kernel_initializer=tf.keras.initializers.VarianceScaling(
                    scale=2.0, mode='fan_in', distribution='truncated_normal'))

        dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
        q_values_layer = tf.keras.layers.Dense(
            num_actions,
            activation=None,
            kernel_initializer=tf.keras.initializers.RandomUniform(
                minval=-0.03, maxval=0.03),
            bias_initializer=tf.keras.initializers.Constant(-0.2))
        
        self.q_net = sequential.Sequential(dense_layers + [q_values_layer])
        
        optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
        #rain_step_counter = tf.Variable(0)

        self.agent = dqn_agent.DqnAgent(
            time_step_spec = self.train_env.time_step_spec(),
            action_spec = self.train_env.action_spec(),
            q_network=self.q_net,
            optimizer=optimizer,
            td_errors_loss_fn=common.element_wise_squared_loss,
            train_step_counter=tf.Variable(0))

        self.agent.initialize()
        
        self.eval_policy = self.agent.policy
        self.collect_policy = self.agent.collect_policy
        self.random_policy = random_tf_policy.RandomTFPolicy(self.train_env.time_step_spec(),self.train_env.action_spec())
        return True

    def compute_avg_return(self, environment, policy, num_episodes=10):
        #mT.compute_avg_return(mT.eval_env, mT.random_policy, 50)
        total_return = 0.0
        for _ in range(num_episodes):
            time_step = environment.reset()
            episode_return = 0.0
            while not time_step.is_last():
                action_step = policy.action(time_step)
                time_step = environment.step(action_step.action)
                episode_return += time_step.reward
            total_return += episode_return
        avg_return = total_return / num_episodes
        print ('average return :', avg_return.numpy()[0])
        return avg_return.numpy()[0]

    def create_replaybuffer(self):

        table_name = 'uniform_table'
        replay_buffer_signature = tensor_spec.from_spec(self.agent.collect_data_spec)
        replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature)

        table = reverb.Table(table_name,
                             max_size=self.replay_buffer_max_length,
                             sampler=reverb.selectors.Uniform(),
                             remover=reverb.selectors.Fifo(),
                             rate_limiter=reverb.rate_limiters.MinSize(1),
                             signature=replay_buffer_signature)

        reverb_server = reverb.Server([table])

        self.replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
                            self.agent.collect_data_spec,
                            table_name=table_name,
                            sequence_length=2,
                            local_server=reverb_server)

        self.rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
                            self.replay_buffer.py_client,
                            table_name,
                            sequence_length=2)
    
        self.dataset = self.replay_buffer.as_dataset(num_parallel_calls=3,sample_batch_size=self.batch_size,num_steps=2).prefetch(3)
        self.iterator = iter(self.dataset)

    def testReplayBuffer(self):
        py_driver.PyDriver(
            self.train_env,
            py_tf_eager_policy.PyTFEagerPolicy(
              self.random_policy, 
                use_tf_function=True),
            [self.rb_observer],
            max_steps=self.initial_collect_steps).run(self.train_env.reset())        
        
    def trainAgent(self):
        
        print (self.collect_policy)
        # Create a driver to collect experience.
        collect_driver = py_driver.PyDriver(
            self.train_env, 
            py_tf_eager_policy.PyTFEagerPolicy(
               self.agent.collect_policy,
                batch_time_steps=False,
                use_tf_function=True),
            [self.rb_observer],
            max_steps=self.collect_steps_per_iteration)

        
        # Reset the environment.
        time_step = self.train_env.reset()
        
        for _ in range(self.num_iterations):

            # Collect a few steps and save to the replay buffer.
            time_step, _ = collect_driver.run(time_step)

            # Sample a batch of data from the buffer and update the agent's network.
            experience, unused_info = next(self.iterator)
            train_loss = agent.train(experience).loss

            step = agent.train_step_counter.numpy()

            if step % log_interval == 0:
                print('step = {0}: loss = {1}'.format(step, train_loss))

            if step % eval_interval == 0:
                avg_return = self.compute_avg_return(self.eval_env, agent.policy, num_eval_episodes)
                print('step = {0}: Average Return = {1}'.format(step, avg_return))
                self.returns.append(avg_return)
        
        
        
    
    def run(self):
        self.createAgent()
        #self.compute_avg_return(self.train_env,self.eval_policy)
        self.create_replaybuffer()
        #self.testReplayBuffer()
        self.trainAgent()
        return True

mT = mTrainer()
mT.run()

It produces this error message:

InvalidArgumentError: Received incompatible tensor at flattened index 4 from table 'uniform_table'. Specification has (dtype, shape): (int32, [?]). Tensor has (dtype, shape): (int32, [2,1]). Table signature: 0: Tensor<name: 'key', dtype: uint64, shape: []>, 1: Tensor<name: 'probability', dtype: double, shape: []>, 2: Tensor<name: 'table_size', dtype: int64, shape: []>, 3: Tensor<name: 'priority', dtype: double, shape: []>, 4: Tensor<name: 'step_type/step_type', dtype: int32, shape: [?]>, 5: Tensor<name: 'observation/observation', dtype: int32, shape: [?,441]>, 6: Tensor<name: 'action/action', dtype: int32, shape: [?]>, 7: Tensor<name: 'next_step_type/step_type', dtype: int32, shape: [?]>, 8: Tensor<name: 'reward/reward', dtype: float, shape: [?]>, 9: Tensor<name: 'discount/discount', dtype: float, shape: [?]> [Op:IteratorGetNext]

Upvotes: 1

Views: 770

Answers (1)

Abhi
Abhi

Reputation: 126

I got stuck with a similar issue, the reason for it is that, you are using tensorflow environment as the parameter of the PyDriver to collect the data. Tensorflow environment adds a batch dimension to all the tensors that it produces, therefore, each time_step generated will have an additional dimension whose value will be 1.

Now, when you retrieve the data from the replay buffer, each of time_step will have an additional dimension and it is not compatible with the data that the train function of the agent is expecting, hence the error.

You need to use a python environment here in order to collect the data with right dimension. Also, now you don't have to use batch_time_steps = False.

I am not sure how to collect the data with right dimensions with a tensorflow environment so I have modified your code a bit to allow data collection using python environment and it should run now.

PS - There were a few trivial bugs in the code you posted (ex. using log_interval instead of self.log_interval etc).

Agent Class `

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import numpy as np
    import random
    from IPython.display import clear_output
    import time



    import abc
    import tensorflow as tf
    import numpy as np

    from tf_agents.environments import py_environment
    from tf_agents.environments import tf_environment
    from tf_agents.environments import tf_py_environment
    from tf_agents.environments import utils
    from tf_agents.specs import array_spec
    from tf_agents.environments import wrappers
    from tf_agents.environments import suite_gym
    from tf_agents.trajectories import time_step as ts


    class cGame(py_environment.PyEnvironment):
        def __init__(self):
            self.xdim = 21
            self.ydim = 21
            self.mmap = np.array([[0] * self.xdim] * self.ydim)
            self._turnNumber = 0
            self.playerPos = {"x": 1, "y": 1}
            self.totalScore = 0
            self.reward = 0.0
            self.input = 0
            self.addRewardEveryNTurns = 4
            self.addBombEveryNTurns = 3
            self._episode_ended = False

            ## player = 13
            ## bomb   = 14

            self._action_spec = array_spec.BoundedArraySpec(shape=(),
                                                            dtype=np.int32,
                                                            minimum=0, maximum=3,
                                                            name='action')
            self._observation_spec = array_spec.BoundedArraySpec(shape=(441,),
                                                                 minimum=np.array(
                                                                     [-1] * 441),
                                                                 maximum=np.array(
                                                                     [20] * 441),
                                                                 dtype=np.int32,
                                                                 name='observation')  # (self.xdim, self.ydim)  , self.mmap.shape,  minimum = -1, maximum = 10

        def action_spec(self):
            return self._action_spec

        def observation_spec(self):
            return self._observation_spec

        def addMapReward(self):
            dx = random.randint(1, self.xdim - 2)
            dy = random.randint(1, self.ydim - 2)
            if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
                self.mmap[dy][dx] = random.randint(1, 9)
            return True

        def addBombToMap(self):
            dx = random.randint(1, self.xdim - 2)
            dy = random.randint(1, self.ydim - 2)
            if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
                self.mmap[dy][dx] = 14
            return True

        def _reset(self):
            self.mmap = np.array([[0] * self.xdim] * self.ydim)
            for y in range(self.ydim):
                self.mmap[y][0] = -1
                self.mmap[y][self.ydim - 1] = -1
            for x in range(self.xdim):
                self.mmap[0][x] = -1
                self.mmap[self.ydim - 1][x] = -1

            self.playerPos["x"] = random.randint(1, self.xdim - 2)
            self.playerPos["y"] = random.randint(1, self.ydim - 2)
            self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13

            for z in range(10):
                ## place 10 targets
                self.addMapReward()
            for z in range(5):
                ## place 5 bombs
                ## bomb   = 14
                self.addBombToMap()
            self._turnNumber = 0
            self._episode_ended = False
            # return ts.restart (self.mmap)
            dap = ts.restart(np.array(self.mmap, dtype=np.int32).flatten())
            return (dap)

        def render(self, mapToRender):
            mapToRender.reshape(21, 21)
            for y in range(self.ydim):
                o = ""
                for x in range(self.xdim):
                    if mapToRender[y][x] == -1:
                        o = o + "#"
                    elif mapToRender[y][x] > 0 and mapToRender[y][x] < 10:
                        o = o + str(mapToRender[y][x])
                    elif mapToRender[y][x] == 13:
                        o = o + "@"
                    elif mapToRender[y][x] == 14:
                        o = o + "*"
                    else:
                        o = o + " "
                print(o)
            print('TOTAL SCORE:', self.totalScore, 'LAST TURN SCORE:', self.reward)
            return True

        def getInput(self):
            self.input = 0
            i = input()
            if i == 'w' or i == '0':
                print('going N')
                self.input = 1
            if i == 's' or i == '1':
                print('going S')
                self.input = 2
            if i == 'a' or i == '2':
                print('going W')
                self.input = 3
            if i == 'd' or i == '3':
                print('going E')
                self.input = 4
            if i == 'x':
                self.input = 5
            return self.input

        def processMove(self):

            self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
            self.reward = 0
            if self.input == 0:
                self.playerPos["y"] -= 1
            if self.input == 1:
                self.playerPos["y"] += 1
            if self.input == 2:
                self.playerPos["x"] -= 1
            if self.input == 3:
                self.playerPos["x"] += 1

            cloc = self.mmap[self.playerPos["y"]][self.playerPos["x"]]

            if cloc == -1 or cloc == 14:
                self.totalScore = 0
                self.reward = -99

            if cloc > 0 and cloc < 10:
                self.totalScore += cloc
                self.reward = cloc
                self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0

            self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13

            self.render(self.mmap)

        def runTurn(self):
            clear_output(wait=True)
            if self._turnNumber % self.addRewardEveryNTurns == 0:
                self.addMapReward()
            if self._turnNumber % self.addBombEveryNTurns == 0:
                self.addBombToMap()

            self.getInput()
            self.processMove()
            self._turnNumber += 1
            if self.reward == -99:
                self._turnNumber += 1
                self._reset()
                self.totalScore = 0
                self.render(self.mmap)
            return (self.reward)

        def _step(self, action):

            if self._episode_ended == True:
                return self._reset()

            clear_output(wait=True)
            if self._turnNumber % self.addRewardEveryNTurns == 0:
                self.addMapReward()
            if self._turnNumber % self.addBombEveryNTurns == 0:
                self.addBombToMap()

            ## make sure action does produce exceed range
            # if action > 5 or action <1:
            #    action =0
            self.input = action  ## value 1 to 4
            self.processMove()
            self._turnNumber += 1

            if self.reward == -99:
                self._turnNumber += 1
                self._episode_ended = True
                # self._reset()
                self.totalScore = 0
                self.render(self.mmap)
                return ts.termination(np.array(self.mmap, dtype=np.int32).flatten(),
                                      reward=self.reward)
            else:
                return ts.transition(np.array(self.mmap, dtype=np.int32).flatten(),
                                     reward=self.reward)  # , discount = 1.0

        def run(self):
            self._reset()
            self.render(self.mmap)
            while (True):
                self.runTurn()
                if self.input == 5:
                    return ("EXIT on input x ")


    env = cGame()

`

Driver Code `

    from tf_agents.specs import tensor_spec
    from tf_agents.networks import sequential
    from tf_agents.agents.dqn import dqn_agent
    from tf_agents.utils import common
    from tf_agents.policies import py_tf_eager_policy
    from tf_agents.policies import random_tf_policy
    import reverb
    from tf_agents.replay_buffers import reverb_replay_buffer
    from tf_agents.replay_buffers import reverb_utils
    from tf_agents.trajectories import trajectory
    from tf_agents.drivers import py_driver
    from tf_agents.environments import BatchedPyEnvironment
    
    
    class mTrainer:
        def __init__(self):
    
            self.returns = None
            self.train_env = tf_py_environment.TFPyEnvironment(cGame())
            self.eval_env = tf_py_environment.TFPyEnvironment(cGame())
    
            self.num_iterations = 20000  # @param {type:"integer"}
            self.initial_collect_steps = 100  # @param {type:"integer"}
            self.collect_steps_per_iteration = 100  # @param {type:"integer"}
            self.replay_buffer_max_length = 100000  # @param {type:"integer"}
            self.batch_size = 64  # @param {type:"integer"}
            self.learning_rate = 1e-3  # @param {type:"number"}
            self.log_interval = 200  # @param {type:"integer"}
            self.num_eval_episodes = 10  # @param {type:"integer"}
            self.eval_interval = 1000  # @param {type:"integer"}
    
        def createAgent(self):
            fc_layer_params = (100, 50)
            action_tensor_spec = tensor_spec.from_spec(self.train_env.action_spec())
            num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    
            def dense_layer(num_units):
                return tf.keras.layers.Dense(
                    num_units,
                    activation=tf.keras.activations.relu,
                    kernel_initializer=tf.keras.initializers.VarianceScaling(
                        scale=2.0, mode='fan_in', distribution='truncated_normal'))
    
            dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
            q_values_layer = tf.keras.layers.Dense(
                num_actions,
                activation=None,
                kernel_initializer=tf.keras.initializers.RandomUniform(
                    minval=-0.03, maxval=0.03),
                bias_initializer=tf.keras.initializers.Constant(-0.2))
    
            self.q_net = sequential.Sequential(dense_layers + [q_values_layer])
    
            optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
            # rain_step_counter = tf.Variable(0)
    
            self.agent = dqn_agent.DqnAgent(
                time_step_spec=self.train_env.time_step_spec(),
                action_spec=self.train_env.action_spec(),
                q_network=self.q_net,
                optimizer=optimizer,
                td_errors_loss_fn=common.element_wise_squared_loss,
                train_step_counter=tf.Variable(0))
    
            self.agent.initialize()
    
            self.eval_policy = self.agent.policy
            self.collect_policy = self.agent.collect_policy
            self.random_policy = random_tf_policy.RandomTFPolicy(
                self.train_env.time_step_spec(), self.train_env.action_spec())
            return True
    
        def compute_avg_return(self, environment, policy, num_episodes=10):
            # mT.compute_avg_return(mT.eval_env, mT.random_policy, 50)
            total_return = 0.0
            for _ in range(num_episodes):
                time_step = environment.reset()
                episode_return = 0.0
                while not time_step.is_last():
                    action_step = policy.action(time_step)
                    time_step = environment.step(action_step.action)
                    episode_return += time_step.reward
                total_return += episode_return
            avg_return = total_return / num_episodes
            print('average return :', avg_return.numpy()[0])
            return avg_return.numpy()[0]
    
        def create_replaybuffer(self):
    
            table_name = 'uniform_table'
            replay_buffer_signature = tensor_spec.from_spec(
                self.agent.collect_data_spec)
            replay_buffer_signature = tensor_spec.add_outer_dim(
                replay_buffer_signature)
    
            table = reverb.Table(table_name,
                                 max_size=self.replay_buffer_max_length,
                                 sampler=reverb.selectors.Uniform(),
                                 remover=reverb.selectors.Fifo(),
                                 rate_limiter=reverb.rate_limiters.MinSize(1),
                                 signature=replay_buffer_signature)
    
            reverb_server = reverb.Server([table])
    
            self.replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
                self.agent.collect_data_spec,
                table_name=table_name,
                sequence_length=2,
                local_server=reverb_server)
    
            self.rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
                self.replay_buffer.py_client,
                table_name,
                sequence_length=2)
    
            self.dataset = self.replay_buffer.as_dataset(num_parallel_calls=3,
                                                         sample_batch_size=self.batch_size,
                                                         num_steps=2).prefetch(3)
            self.iterator = iter(self.dataset)
    
        def testReplayBuffer(self):
            py_env = cGame()
            py_driver.PyDriver(
                py_env,
                py_tf_eager_policy.PyTFEagerPolicy(
                    self.random_policy,
                    use_tf_function=True),
                [self.rb_observer],
                max_steps=self.initial_collect_steps).run(self.train_env.reset())
    
        def trainAgent(self):
    
            self.returns = list()
            print(self.collect_policy)
            py_env = cGame()
            # Create a driver to collect experience.
            collect_driver = py_driver.PyDriver(
                py_env, # CHANGE 1
                py_tf_eager_policy.PyTFEagerPolicy(
                    self.agent.collect_policy,
                    # batch_time_steps=False, # CHANGE 2
                    use_tf_function=True),
                [self.rb_observer],
                max_steps=self.collect_steps_per_iteration)
    
            # Reset the environment.
            # time_step = self.train_env.reset()
            time_step = py_env.reset()
            for _ in range(self.num_iterations):
    
                # Collect a few steps and save to the replay buffer.
                time_step, _ = collect_driver.run(time_step)
    
                # Sample a batch of data from the buffer and update the agent's network.
                experience, unused_info = next(self.iterator)
                train_loss = self.agent.train(experience).loss
    
                step = self.agent.train_step_counter.numpy()
    
                if step % self.log_interval == 0:
                    print('step = {0}: loss = {1}'.format(step, train_loss))
    
                if step % self.eval_interval == 0:
                    avg_return = self.compute_avg_return(self.eval_env,
                                                         self.agent.policy,
                                                         self.num_eval_episodes)
                    print(
                        'step = {0}: Average Return = {1}'.format(step, avg_return))
                    self.returns.append(avg_return)
    
        def run(self):
            self.createAgent()
            # self.compute_avg_return(self.train_env,self.eval_policy)
            self.create_replaybuffer()
            # self.testReplayBuffer()
            self.trainAgent()
            return True
    
    if __name__ == '__main__':
        mT = mTrainer()
        mT.run()

`

Upvotes: 0

Related Questions