Reputation: 11
I'm trying to adapt the TensorFlow Agents tutorial to a custom environment. It's not very complicated and is just meant to teach me how this works. The game is basically a 21x21 grid with tokens the agent can collect for a reward by walking around. I can validate the environment, the agent, and the replay buffer, but when I try to train the model, I get an error message (see bottom). Any advice would be welcome!
The environment class is:
import numpy as np
import random
from IPython.display import clear_output
import time
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import tensorflow as tf
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts
class cGame (py_environment.PyEnvironment):
def __init__(self):
self.xdim = 21
self.ydim = 21
self.mmap = np.array([[0]*self.xdim]*self.ydim)
self._turnNumber = 0
self.playerPos = {"x":1, "y":1}
self.totalScore = 0
self.reward = 0.0
self.input = 0
self.addRewardEveryNTurns = 4
self.addBombEveryNTurns = 3
self._episode_ended = False
## player = 13
## bomb = 14
self._action_spec = array_spec.BoundedArraySpec(shape=(), dtype=np.int32, minimum=0, maximum=3, name='action')
self._observation_spec = array_spec.BoundedArraySpec(shape = (441,), minimum=np.array([-1]*441), maximum = np.array([20]*441), dtype=np.int32, name='observation') #(self.xdim, self.ydim) , self.mmap.shape, minimum = -1, maximum = 10
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def addMapReward(self):
dx = random.randint(1, self.xdim-2)
dy = random.randint(1, self.ydim-2)
if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
self.mmap[dy][dx] = random.randint(1, 9)
return True
def addBombToMap(self):
dx = random.randint(1, self.xdim-2)
dy = random.randint(1, self.ydim-2)
if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
self.mmap[dy][dx] = 14
return True
def _reset (self):
self.mmap = np.array([[0]*self.xdim]*self.ydim)
for y in range(self.ydim):
self.mmap[y][0] = -1
self.mmap[y][self.ydim-1] = -1
for x in range(self.xdim):
self.mmap[0][x] = -1
self.mmap[self.ydim-1][x] = -1
self.playerPos["x"] = random.randint(1, self.xdim-2)
self.playerPos["y"] = random.randint(1, self.ydim-2)
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
for z in range(10):
## place 10 targets
self.addMapReward()
for z in range(5):
## place 5 bombs
## bomb = 14
self.addBombToMap()
self._turnNumber = 0
self._episode_ended = False
#return ts.restart (self.mmap)
dap = ts.restart(np.array(self.mmap, dtype=np.int32).flatten())
return (dap)
def render(self, mapToRender):
mapToRender.reshape(21,21)
for y in range(self.ydim):
o =""
for x in range(self.xdim):
if mapToRender[y][x]==-1:
o=o+"#"
elif mapToRender[y][x]>0 and mapToRender[y][x]<10:
o=o+str(mapToRender[y][x])
elif mapToRender[y][x] == 13:
o=o+"@"
elif mapToRender[y][x] == 14:
o=o+"*"
else:
o=o+" "
print (o)
print ('TOTAL SCORE:', self.totalScore, 'LAST TURN SCORE:', self.reward)
return True
def getInput(self):
self.input = 0
i = input()
if i == 'w' or i == '0':
print ('going N')
self.input = 1
if i == 's' or i == '1':
print ('going S')
self.input = 2
if i == 'a' or i == '2':
print ('going W')
self.input = 3
if i == 'd' or i == '3':
print ('going E')
self.input = 4
if i == 'x':
self.input = 5
return self.input
def processMove(self):
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
self.reward = 0
if self.input == 0:
self.playerPos["y"] -=1
if self.input == 1:
self.playerPos["y"] +=1
if self.input == 2:
self.playerPos["x"] -=1
if self.input == 3:
self.playerPos["x"] +=1
cloc = self.mmap[self.playerPos["y"]][self.playerPos["x"]]
if cloc == -1 or cloc ==14:
self.totalScore = 0
self.reward = -99
if cloc >0 and cloc < 10:
self.totalScore += cloc
self.reward = cloc
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
self.render(self.mmap)
def runTurn(self):
clear_output(wait=True)
if self._turnNumber % self.addRewardEveryNTurns == 0:
self.addMapReward()
if self._turnNumber % self.addBombEveryNTurns == 0:
self.addBombToMap()
self.getInput()
self.processMove()
self._turnNumber +=1
if self.reward == -99:
self._turnNumber +=1
self._reset()
self.totalScore = 0
self.render(self.mmap)
return (self.reward)
def _step (self, action):
if self._episode_ended == True:
return self._reset()
clear_output(wait=True)
if self._turnNumber % self.addRewardEveryNTurns == 0:
self.addMapReward()
if self._turnNumber % self.addBombEveryNTurns == 0:
self.addBombToMap()
        ## make sure the action does not exceed the valid range
#if action > 5 or action <1:
# action =0
        self.input = action ## value 0 to 3
self.processMove()
self._turnNumber +=1
if self.reward == -99:
self._turnNumber +=1
self._episode_ended = True
#self._reset()
self.totalScore = 0
self.render(self.mmap)
return ts.termination(np.array(self.mmap, dtype=np.int32).flatten(), reward = self.reward)
else:
return ts.transition(np.array(self.mmap, dtype=np.int32).flatten(), reward = self.reward) #, discount = 1.0
def run (self):
self._reset()
self.render(self.mmap)
while (True):
self.runTurn()
if self.input == 5:
return ("EXIT on input x ")
env = cGame()
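For reference, this is roughly how I validate the environment (a minimal sketch using the tf_agents validation utility):

```python
# Minimal sketch of the environment validation step mentioned above.
# Runs a few episodes with random actions and checks that every time step
# conforms to the declared observation/action specs of cGame.
from tf_agents.environments import utils

utils.validate_py_environment(cGame(), episodes=5)
```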
The class I want to use for training the model is:
from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
import reverb
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.drivers import py_driver
from tf_agents.environments import BatchedPyEnvironment
class mTrainer:
def __init__ (self):
self.train_env = tf_py_environment.TFPyEnvironment(cGame())
self.eval_env = tf_py_environment.TFPyEnvironment(cGame())
self.num_iterations = 20000 # @param {type:"integer"}
self.initial_collect_steps = 100 # @param {type:"integer"}
self.collect_steps_per_iteration = 100 # @param {type:"integer"}
self.replay_buffer_max_length = 100000 # @param {type:"integer"}
self.batch_size = 64 # @param {type:"integer"}
self.learning_rate = 1e-3 # @param {type:"number"}
self.log_interval = 200 # @param {type:"integer"}
self.num_eval_episodes = 10 # @param {type:"integer"}
self.eval_interval = 1000 # @param {type:"integer"}
def createAgent(self):
fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(self.train_env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
def dense_layer(num_units):
return tf.keras.layers.Dense(
num_units,
activation=tf.keras.activations.relu,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2.0, mode='fan_in', distribution='truncated_normal'))
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
num_actions,
activation=None,
kernel_initializer=tf.keras.initializers.RandomUniform(
minval=-0.03, maxval=0.03),
bias_initializer=tf.keras.initializers.Constant(-0.2))
self.q_net = sequential.Sequential(dense_layers + [q_values_layer])
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
#rain_step_counter = tf.Variable(0)
self.agent = dqn_agent.DqnAgent(
time_step_spec = self.train_env.time_step_spec(),
action_spec = self.train_env.action_spec(),
q_network=self.q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=tf.Variable(0))
self.agent.initialize()
self.eval_policy = self.agent.policy
self.collect_policy = self.agent.collect_policy
self.random_policy = random_tf_policy.RandomTFPolicy(self.train_env.time_step_spec(),self.train_env.action_spec())
return True
def compute_avg_return(self, environment, policy, num_episodes=10):
#mT.compute_avg_return(mT.eval_env, mT.random_policy, 50)
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
print ('average return :', avg_return.numpy()[0])
return avg_return.numpy()[0]
def create_replaybuffer(self):
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(self.agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature)
table = reverb.Table(table_name,
max_size=self.replay_buffer_max_length,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
signature=replay_buffer_signature)
reverb_server = reverb.Server([table])
self.replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
self.agent.collect_data_spec,
table_name=table_name,
sequence_length=2,
local_server=reverb_server)
self.rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
self.replay_buffer.py_client,
table_name,
sequence_length=2)
self.dataset = self.replay_buffer.as_dataset(num_parallel_calls=3,sample_batch_size=self.batch_size,num_steps=2).prefetch(3)
self.iterator = iter(self.dataset)
def testReplayBuffer(self):
py_driver.PyDriver(
self.train_env,
py_tf_eager_policy.PyTFEagerPolicy(
self.random_policy,
use_tf_function=True),
[self.rb_observer],
max_steps=self.initial_collect_steps).run(self.train_env.reset())
def trainAgent(self):
print (self.collect_policy)
# Create a driver to collect experience.
collect_driver = py_driver.PyDriver(
self.train_env,
py_tf_eager_policy.PyTFEagerPolicy(
self.agent.collect_policy,
batch_time_steps=False,
use_tf_function=True),
[self.rb_observer],
max_steps=self.collect_steps_per_iteration)
# Reset the environment.
time_step = self.train_env.reset()
for _ in range(self.num_iterations):
# Collect a few steps and save to the replay buffer.
time_step, _ = collect_driver.run(time_step)
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(self.iterator)
train_loss = agent.train(experience).loss
step = agent.train_step_counter.numpy()
if step % log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss))
if step % eval_interval == 0:
avg_return = self.compute_avg_return(self.eval_env, agent.policy, num_eval_episodes)
print('step = {0}: Average Return = {1}'.format(step, avg_return))
self.returns.append(avg_return)
def run(self):
self.createAgent()
#self.compute_avg_return(self.train_env,self.eval_policy)
self.create_replaybuffer()
#self.testReplayBuffer()
self.trainAgent()
return True
mT = mTrainer()
mT.run()
It produces this error message:
InvalidArgumentError: Received incompatible tensor at flattened index 4 from table 'uniform_table'. Specification has (dtype, shape): (int32, [?]). Tensor has (dtype, shape): (int32, [2,1]). Table signature: 0: Tensor<name: 'key', dtype: uint64, shape: []>, 1: Tensor<name: 'probability', dtype: double, shape: []>, 2: Tensor<name: 'table_size', dtype: int64, shape: []>, 3: Tensor<name: 'priority', dtype: double, shape: []>, 4: Tensor<name: 'step_type/step_type', dtype: int32, shape: [?]>, 5: Tensor<name: 'observation/observation', dtype: int32, shape: [?,441]>, 6: Tensor<name: 'action/action', dtype: int32, shape: [?]>, 7: Tensor<name: 'next_step_type/step_type', dtype: int32, shape: [?]>, 8: Tensor<name: 'reward/reward', dtype: float, shape: [?]>, 9: Tensor<name: 'discount/discount', dtype: float, shape: [?]> [Op:IteratorGetNext]
Upvotes: 1
Views: 770
Reputation: 126
I got stuck with a similar issue. The cause is that you are passing a TensorFlow environment to the PyDriver that collects the data. A TensorFlow environment adds a batch dimension to every tensor it produces, so each time_step it generates carries an extra leading dimension of size 1.
When you later retrieve the data from the replay buffer, each time_step still has that additional dimension, which is not compatible with what the agent's train function expects, hence the error.
You need to use a Python environment here to collect data with the right dimensions. With that change you also no longer need batch_time_steps=False.
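For illustration, here is a quick sketch (reusing the cGame class and the tf_py_environment import from the question) of the extra batch dimension the TF wrapper introduces:

```python
# The raw Python environment yields unbatched time steps, while the
# TFPyEnvironment wrapper adds a leading batch dimension of size 1.
py_env = cGame()
tf_env = tf_py_environment.TFPyEnvironment(cGame())

print(py_env.reset().observation.shape)  # (441,)   -- matches the replay buffer signature
print(tf_env.reset().observation.shape)  # (1, 441) -- extra batch dimension
```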
I am not sure how to collect data with the right dimensions from a TensorFlow environment, so I have modified your code a bit to collect data with a Python environment instead; it should run now.
PS: there were a few trivial bugs in the code you posted (e.g. using log_interval instead of self.log_interval, and agent instead of self.agent, in trainAgent).
Agent Class:

```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import random
from IPython.display import clear_output
import time
import abc
import tensorflow as tf
import numpy as np
from tf_agents.environments import py_environment
from tf_agents.environments import tf_environment
from tf_agents.environments import tf_py_environment
from tf_agents.environments import utils
from tf_agents.specs import array_spec
from tf_agents.environments import wrappers
from tf_agents.environments import suite_gym
from tf_agents.trajectories import time_step as ts
class cGame(py_environment.PyEnvironment):
def __init__(self):
self.xdim = 21
self.ydim = 21
self.mmap = np.array([[0] * self.xdim] * self.ydim)
self._turnNumber = 0
self.playerPos = {"x": 1, "y": 1}
self.totalScore = 0
self.reward = 0.0
self.input = 0
self.addRewardEveryNTurns = 4
self.addBombEveryNTurns = 3
self._episode_ended = False
## player = 13
## bomb = 14
self._action_spec = array_spec.BoundedArraySpec(shape=(),
dtype=np.int32,
minimum=0, maximum=3,
name='action')
self._observation_spec = array_spec.BoundedArraySpec(shape=(441,),
minimum=np.array(
[-1] * 441),
maximum=np.array(
[20] * 441),
dtype=np.int32,
name='observation') # (self.xdim, self.ydim) , self.mmap.shape, minimum = -1, maximum = 10
def action_spec(self):
return self._action_spec
def observation_spec(self):
return self._observation_spec
def addMapReward(self):
dx = random.randint(1, self.xdim - 2)
dy = random.randint(1, self.ydim - 2)
if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
self.mmap[dy][dx] = random.randint(1, 9)
return True
def addBombToMap(self):
dx = random.randint(1, self.xdim - 2)
dy = random.randint(1, self.ydim - 2)
if dx != self.playerPos["x"] and dy != self.playerPos["y"]:
self.mmap[dy][dx] = 14
return True
def _reset(self):
self.mmap = np.array([[0] * self.xdim] * self.ydim)
for y in range(self.ydim):
self.mmap[y][0] = -1
self.mmap[y][self.ydim - 1] = -1
for x in range(self.xdim):
self.mmap[0][x] = -1
self.mmap[self.ydim - 1][x] = -1
self.playerPos["x"] = random.randint(1, self.xdim - 2)
self.playerPos["y"] = random.randint(1, self.ydim - 2)
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
for z in range(10):
## place 10 targets
self.addMapReward()
for z in range(5):
## place 5 bombs
## bomb = 14
self.addBombToMap()
self._turnNumber = 0
self._episode_ended = False
# return ts.restart (self.mmap)
dap = ts.restart(np.array(self.mmap, dtype=np.int32).flatten())
return (dap)
def render(self, mapToRender):
mapToRender.reshape(21, 21)
for y in range(self.ydim):
o = ""
for x in range(self.xdim):
if mapToRender[y][x] == -1:
o = o + "#"
elif mapToRender[y][x] > 0 and mapToRender[y][x] < 10:
o = o + str(mapToRender[y][x])
elif mapToRender[y][x] == 13:
o = o + "@"
elif mapToRender[y][x] == 14:
o = o + "*"
else:
o = o + " "
print(o)
print('TOTAL SCORE:', self.totalScore, 'LAST TURN SCORE:', self.reward)
return True
def getInput(self):
self.input = 0
i = input()
if i == 'w' or i == '0':
print('going N')
self.input = 1
if i == 's' or i == '1':
print('going S')
self.input = 2
if i == 'a' or i == '2':
print('going W')
self.input = 3
if i == 'd' or i == '3':
print('going E')
self.input = 4
if i == 'x':
self.input = 5
return self.input
def processMove(self):
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
self.reward = 0
if self.input == 0:
self.playerPos["y"] -= 1
if self.input == 1:
self.playerPos["y"] += 1
if self.input == 2:
self.playerPos["x"] -= 1
if self.input == 3:
self.playerPos["x"] += 1
cloc = self.mmap[self.playerPos["y"]][self.playerPos["x"]]
if cloc == -1 or cloc == 14:
self.totalScore = 0
self.reward = -99
if cloc > 0 and cloc < 10:
self.totalScore += cloc
self.reward = cloc
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 0
self.mmap[self.playerPos["y"]][self.playerPos["x"]] = 13
self.render(self.mmap)
def runTurn(self):
clear_output(wait=True)
if self._turnNumber % self.addRewardEveryNTurns == 0:
self.addMapReward()
if self._turnNumber % self.addBombEveryNTurns == 0:
self.addBombToMap()
self.getInput()
self.processMove()
self._turnNumber += 1
if self.reward == -99:
self._turnNumber += 1
self._reset()
self.totalScore = 0
self.render(self.mmap)
return (self.reward)
def _step(self, action):
if self._episode_ended == True:
return self._reset()
clear_output(wait=True)
if self._turnNumber % self.addRewardEveryNTurns == 0:
self.addMapReward()
if self._turnNumber % self.addBombEveryNTurns == 0:
self.addBombToMap()
        ## make sure the action does not exceed the valid range
# if action > 5 or action <1:
# action =0
        self.input = action  ## value 0 to 3
self.processMove()
self._turnNumber += 1
if self.reward == -99:
self._turnNumber += 1
self._episode_ended = True
# self._reset()
self.totalScore = 0
self.render(self.mmap)
return ts.termination(np.array(self.mmap, dtype=np.int32).flatten(),
reward=self.reward)
else:
return ts.transition(np.array(self.mmap, dtype=np.int32).flatten(),
reward=self.reward) # , discount = 1.0
def run(self):
self._reset()
self.render(self.mmap)
while (True):
self.runTurn()
if self.input == 5:
return ("EXIT on input x ")
env = cGame()
```
Driver Code:

```python
from tf_agents.specs import tensor_spec
from tf_agents.networks import sequential
from tf_agents.agents.dqn import dqn_agent
from tf_agents.utils import common
from tf_agents.policies import py_tf_eager_policy
from tf_agents.policies import random_tf_policy
import reverb
from tf_agents.replay_buffers import reverb_replay_buffer
from tf_agents.replay_buffers import reverb_utils
from tf_agents.trajectories import trajectory
from tf_agents.drivers import py_driver
from tf_agents.environments import BatchedPyEnvironment
class mTrainer:
def __init__(self):
self.returns = None
self.train_env = tf_py_environment.TFPyEnvironment(cGame())
self.eval_env = tf_py_environment.TFPyEnvironment(cGame())
self.num_iterations = 20000 # @param {type:"integer"}
self.initial_collect_steps = 100 # @param {type:"integer"}
self.collect_steps_per_iteration = 100 # @param {type:"integer"}
self.replay_buffer_max_length = 100000 # @param {type:"integer"}
self.batch_size = 64 # @param {type:"integer"}
self.learning_rate = 1e-3 # @param {type:"number"}
self.log_interval = 200 # @param {type:"integer"}
self.num_eval_episodes = 10 # @param {type:"integer"}
self.eval_interval = 1000 # @param {type:"integer"}
def createAgent(self):
fc_layer_params = (100, 50)
action_tensor_spec = tensor_spec.from_spec(self.train_env.action_spec())
num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
def dense_layer(num_units):
return tf.keras.layers.Dense(
num_units,
activation=tf.keras.activations.relu,
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=2.0, mode='fan_in', distribution='truncated_normal'))
dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
q_values_layer = tf.keras.layers.Dense(
num_actions,
activation=None,
kernel_initializer=tf.keras.initializers.RandomUniform(
minval=-0.03, maxval=0.03),
bias_initializer=tf.keras.initializers.Constant(-0.2))
self.q_net = sequential.Sequential(dense_layers + [q_values_layer])
optimizer = tf.keras.optimizers.Adam(learning_rate=self.learning_rate)
# rain_step_counter = tf.Variable(0)
self.agent = dqn_agent.DqnAgent(
time_step_spec=self.train_env.time_step_spec(),
action_spec=self.train_env.action_spec(),
q_network=self.q_net,
optimizer=optimizer,
td_errors_loss_fn=common.element_wise_squared_loss,
train_step_counter=tf.Variable(0))
self.agent.initialize()
self.eval_policy = self.agent.policy
self.collect_policy = self.agent.collect_policy
self.random_policy = random_tf_policy.RandomTFPolicy(
self.train_env.time_step_spec(), self.train_env.action_spec())
return True
def compute_avg_return(self, environment, policy, num_episodes=10):
# mT.compute_avg_return(mT.eval_env, mT.random_policy, 50)
total_return = 0.0
for _ in range(num_episodes):
time_step = environment.reset()
episode_return = 0.0
while not time_step.is_last():
action_step = policy.action(time_step)
time_step = environment.step(action_step.action)
episode_return += time_step.reward
total_return += episode_return
avg_return = total_return / num_episodes
print('average return :', avg_return.numpy()[0])
return avg_return.numpy()[0]
def create_replaybuffer(self):
table_name = 'uniform_table'
replay_buffer_signature = tensor_spec.from_spec(
self.agent.collect_data_spec)
replay_buffer_signature = tensor_spec.add_outer_dim(
replay_buffer_signature)
table = reverb.Table(table_name,
max_size=self.replay_buffer_max_length,
sampler=reverb.selectors.Uniform(),
remover=reverb.selectors.Fifo(),
rate_limiter=reverb.rate_limiters.MinSize(1),
signature=replay_buffer_signature)
reverb_server = reverb.Server([table])
self.replay_buffer = reverb_replay_buffer.ReverbReplayBuffer(
self.agent.collect_data_spec,
table_name=table_name,
sequence_length=2,
local_server=reverb_server)
self.rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
self.replay_buffer.py_client,
table_name,
sequence_length=2)
self.dataset = self.replay_buffer.as_dataset(num_parallel_calls=3,
sample_batch_size=self.batch_size,
num_steps=2).prefetch(3)
self.iterator = iter(self.dataset)
def testReplayBuffer(self):
py_env = cGame()
py_driver.PyDriver(
py_env,
py_tf_eager_policy.PyTFEagerPolicy(
self.random_policy,
use_tf_function=True),
[self.rb_observer],
max_steps=self.initial_collect_steps).run(self.train_env.reset())
def trainAgent(self):
self.returns = list()
print(self.collect_policy)
py_env = cGame()
# Create a driver to collect experience.
collect_driver = py_driver.PyDriver(
py_env, # CHANGE 1
py_tf_eager_policy.PyTFEagerPolicy(
self.agent.collect_policy,
# batch_time_steps=False, # CHANGE 2
use_tf_function=True),
[self.rb_observer],
max_steps=self.collect_steps_per_iteration)
# Reset the environment.
# time_step = self.train_env.reset()
time_step = py_env.reset()
for _ in range(self.num_iterations):
# Collect a few steps and save to the replay buffer.
time_step, _ = collect_driver.run(time_step)
# Sample a batch of data from the buffer and update the agent's network.
experience, unused_info = next(self.iterator)
train_loss = self.agent.train(experience).loss
step = self.agent.train_step_counter.numpy()
if step % self.log_interval == 0:
print('step = {0}: loss = {1}'.format(step, train_loss))
if step % self.eval_interval == 0:
avg_return = self.compute_avg_return(self.eval_env,
self.agent.policy,
self.num_eval_episodes)
print(
'step = {0}: Average Return = {1}'.format(step, avg_return))
self.returns.append(avg_return)
def run(self):
self.createAgent()
# self.compute_avg_return(self.train_env,self.eval_policy)
self.create_replaybuffer()
# self.testReplayBuffer()
self.trainAgent()
return True
if __name__ == '__main__':
mT = mTrainer()
mT.run()
```
Upvotes: 0