Charlie Hou

Reputation: 103

How do we print action distributions in RLlib during training?

I'm trying to print action distributions at the end of each episode to see what my agent is doing. I've attempted to do this in rock_paper_scissors_multiagent.py by including the following method:

import numpy as np
from gym.spaces import Discrete, Tuple

from ray.rllib.models.preprocessors import get_preprocessor

def on_episode_end(info):
    episode = info["episode"]
    policy = episode._policies['learned']
    print(policy.model.base_model.summary())

    obs_space = Tuple((Discrete(3), Discrete(3)))
    prep = get_preprocessor(obs_space)(obs_space)
    curr_state = list((0, 1))
    curr_state = tuple(curr_state)
    curr_state = prep.transform(curr_state)
    logits, _ = policy.model.from_batch({"obs": np.array([curr_state])})
    dist = policy.dist_class(logits, policy.model)
    dist.sample()
    print(dist.logp([0]))

and by adding the callback to the callbacks option of tune.run. However, I get the following error. Is this how I should be trying to print the action distributions after each episode, and if so, what am I doing wrong? The original rock_paper_scissors_multiagent.py example is here: https://github.com/ray-project/ray/blob/master/rllib/examples/rock_paper_scissors_multiagent.py

Traceback (most recent call last):
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/tune/trial_runner.py", line 515, in _process_trial
    result = self.trial_executor.fetch_result(trial)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/tune/ray_trial_executor.py", line 351, in fetch_result
    result = ray.get(trial_future[0])
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/worker.py", line 2121, in get
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ValueError): ray_worker (pid=5765, host=Charlies-MBP.fios-router.home)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 418, in train
    raise e
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/agents/trainer.py", line 407, in train
    result = Trainable.train(self)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/tune/trainable.py", line 176, in train
    result = self._train()
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/agents/trainer_template.py", line 129, in _train
    fetches = self.optimizer.step()
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/optimizers/multi_gpu_optimizer.py", line 140, in step
    self.num_envs_per_worker, self.train_batch_size)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/optimizers/rollout.py", line 29, in collect_samples
    next_sample = ray_get_and_free(fut_sample)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/utils/memory.py", line 33, in ray_get_and_free
    result = ray.get(object_ids)
ray.exceptions.RayTaskError(ValueError): ray_worker (pid=5768, host=Charlies-MBP.fios-router.home)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/rollout_worker.py", line 469, in sample
    batches = [self.input_reader.next()]
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 56, in next
    batches = [self.get_data()]
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 99, in get_data
    item = next(self.rollout_provider)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 319, in _env_runner
    soft_horizon, no_done_at_end)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/evaluation/sampler.py", line 473, in _process_observations
    "episode": episode
  File "rock_paper_scissors_multiagent.py", line 204, in on_episode_end
    logits, _ = policy.model.from_batch({"obs": np.array([curr_state])})
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/models/modelv2.py", line 197, in from_batch
    return self.__call__(input_dict, states, train_batch.get("seq_lens"))
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/models/modelv2.py", line 154, in __call__
    res = self.forward(restored, state or [], seq_lens)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/ray/rllib/models/tf/fcnet_v2.py", line 84, in forward
    model_out, self._value_out = self.base_model(input_dict["obs_flat"])
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 634, in __call__
    outputs = call_fn(inputs, *args, **kwargs)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 751, in call
    return self._run_internal_graph(inputs, training=training, mask=mask)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/network.py", line 893, in _run_internal_graph
    output_tensors = layer(computed_tensors, **kwargs)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/base_layer.py", line 586, in __call__
    self.name)
  File "/Users/charliehou/anaconda3/lib/python3.6/site-packages/tensorflow/python/keras/engine/input_spec.py", line 159, in assert_input_compatibility
    ' but received input with shape ' + str(shape))
ValueError: Input 0 of layer fc_value_1 is incompatible with the layer: expected axis -1 of input shape to have value 3 but received input with shape [1, 6]  

Upvotes: 2

Views: 2683

Answers (1)

Huan

Reputation: 445

You can access the actions stored in the sample batch object via sample_obj.columns(["actions"]) within the on_postprocess_traj callback, defined as follows:

# The callback function

def on_postprocess_traj(info):
    """
    arg: {"agent_id": ..., "episode": ...,
          "pre_batch": (before processing),
          "post_batch": (after processing),
          "all_pre_batches": (other agent ids),
          }

    # https://github.com/ray-project/ray/blob/ee8c9ff7320ec6a2d7d097cd5532005c6aeb216e/rllib/policy/sample_batch.py
    Keys available in sample_obj:
        t
        eps_id
        agent_index
        obs
        actions
        rewards
        prev_actions
        prev_rewards
        dones
        infos
        new_obs
        action_prob
        action_logp
        vf_preds
        behaviour_logits
        unroll_id       
    """
    agt_id = info["agent_id"]
    eps_id = info["episode"].episode_id
    policy_obj = info["pre_batch"][0]
    sample_obj = info["pre_batch"][1]    

    if(agt_id == 'player1'):
        print('agent_id = {}'.format(agt_id))
        print('episode = {}'.format(eps_id))

        #print("on_postprocess_traj info = {}".format(info))
        #print("on_postprocess_traj sample_obj = {}".format(sample_obj))
        print('actions = {}'.format(sample_obj.columns(["actions"])))
    return

You will also need to add the callback function to your config, like this:

             config={"env": RockPaperScissorsEnv,
                     #"eager": True,
                     "gamma": 0.9,
                     "num_workers": 1,
                     "num_envs_per_worker": 4,
                     "sample_batch_size": 10,
                     "train_batch_size": 200,
                     #"multiagent": {"policies_to_train": ["learned"],
                     "multiagent": {"policies_to_train": ["learned", "learned_2"],
                                    "policies": {"always_same": (AlwaysSameHeuristic, Discrete(3), Discrete(3), {}),
                                                 #"beat_last": (BeatLastHeuristic, Discrete(3), Discrete(3), {}),
                                                 "learned": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}}),
                                                 "learned_2": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}}),
                                                 },
                                    "policy_mapping_fn": select_policy,
                                   },
                      "callbacks": {#"on_episode_start": on_episode_start, 
                                    #"on_episode_step": on_episode_step, 
                                    #"on_episode_end": on_episode_end, 
                                    #"on_sample_end": on_sample_end,
                                    "on_postprocess_traj": on_postprocess_traj,
                                    #"on_train_result": on_train_result,
                                    }
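
If you want something closer to the action distribution the original question asks about, rather than only the sampled actions, a variation of the same callback can read the behaviour_logits column (one of the keys listed in the docstring above) and convert it into per-step action probabilities. This is only a sketch, under the assumption that your policy writes behaviour_logits into the batch, as the list above suggests PPO does here; column names can differ between RLlib versions. Register it under the same "on_postprocess_traj" key if you want to use it instead. The output shown next comes from the original callback.

import numpy as np

def on_postprocess_traj_with_probs(info):
    # Sketch only: assumes the pre_batch sample object contains a
    # "behaviour_logits" column (listed in the docstring above).
    agt_id = info["agent_id"]
    sample_obj = info["pre_batch"][1]

    if agt_id == 'player1':
        logits = sample_obj.columns(["behaviour_logits"])[0]  # shape [T, 3]
        # Numerically stable softmax over the action dimension.
        exp = np.exp(logits - logits.max(axis=-1, keepdims=True))
        probs = exp / exp.sum(axis=-1, keepdims=True)
        print('actions = {}'.format(sample_obj.columns(["actions"])[0]))
        print('action probabilities per step =\n{}'.format(probs))
    return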

The result below shows the output from running the rock_paper_scissors_multiagent.py example (with ray[rllib]==0.8.2 in Colab); notice the printout of the agent ID, episode ID, and the action trajectory:

== Status ==
Memory usage on this node: 1.3/12.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 2/2 CPUs, 0/0 GPUs, 0.0/7.18 GiB heap, 0.0/2.44 GiB objects
Result logdir: /root/ray_results/PPO
Number of trials: 1 (1 RUNNING)
Trial name  status  loc
PPO_RockPaperScissorsEnv_979bff44   RUNNING 


(pid=1541) 2020-04-25 12:45:10,823  INFO trainer.py:420 -- Tip: set 'eager': true or the --eager flag to enable TensorFlow eager execution
(pid=1541) 2020-04-25 12:45:10,827  INFO trainer.py:580 -- Current log_level is WARN. For more information, set 'log_level': 'INFO' / 'DEBUG' or use the -v and -vv flags.
(pid=1541) /usr/local/lib/python3.6/dist-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
(pid=1541)   warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
(pid=1587) /usr/local/lib/python3.6/dist-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
(pid=1587)   warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
(pid=1541) 2020-04-25 12:45:19,048  WARNING util.py:37 -- Install gputil for GPU system monitoring.
(pid=1587) agent_id = player1
(pid=1587) episode = 975148816
(pid=1587) actions = [array([1, 1, 0, 2, 0, 0, 1, 2, 1, 2])]
(pid=1587) agent_id = player1
(pid=1587) episode = 942369634
(pid=1587) actions = [array([1, 2, 1, 2, 2, 2, 1, 0, 2, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 296105405
(pid=1587) actions = [array([2, 2, 0, 2, 2, 1, 2, 1, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 475466940
(pid=1587) actions = [array([0, 2, 1, 0, 2, 0, 2, 1, 0, 2])]
(pid=1587) agent_id = player1
(pid=1587) episode = 793839240
(pid=1587) actions = [array([0, 0, 1, 2, 0, 2, 1, 1, 1, 2])]
(pid=1587) agent_id = player1
(pid=1587) episode = 578652318
(pid=1587) actions = [array([0, 1, 0, 0, 2, 1, 2, 2, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 112165627
(pid=1587) actions = [array([2, 1, 2, 1, 0, 0, 0, 1, 1, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 996828544
(pid=1587) actions = [array([1, 2, 2, 2, 0, 0, 1, 2, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 94669775
(pid=1587) actions = [array([1, 0, 1, 1, 2, 0, 2, 1, 2, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1063457620
(pid=1587) actions = [array([1, 0, 2, 1, 2, 2, 1, 2, 2, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1956229719
(pid=1587) actions = [array([0, 0, 2, 1, 2, 2, 2, 1, 2, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 503578202
(pid=1587) actions = [array([1, 2, 0, 0, 0, 0, 1, 0, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1599756661
(pid=1587) actions = [array([0, 0, 1, 2, 0, 2, 2, 2, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1333277267
(pid=1587) actions = [array([0, 2, 1, 0, 1, 1, 2, 2, 2, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1832916757
(pid=1587) actions = [array([1, 1, 0, 0, 2, 1, 0, 1, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 585983090
(pid=1587) actions = [array([1, 2, 1, 2, 2, 1, 0, 2, 0, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 1731969708
(pid=1587) actions = [array([2, 1, 0, 2, 2, 0, 0, 0, 1, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 374111939
(pid=1587) actions = [array([0, 0, 0, 2, 0, 2, 2, 0, 1, 0])]
(pid=1587) agent_id = player1
(pid=1587) episode = 399432786
(pid=1587) actions = [array([0, 2, 0, 0, 0, 1, 0, 0, 1, 1])]
(pid=1587) agent_id = player1
(pid=1587) episode = 396598872
(pid=1587) actions = [array([1, 1, 0, 2, 0, 2, 0, 2, 1, 0])]
Result for PPO_RockPaperScissorsEnv_979bff44:
  custom_metrics: {}
  date: 2020-04-25_12-45-24
  done: true
  episode_len_mean: 10.0
  episode_reward_max: 0.0
  episode_reward_mean: 0.0
  episode_reward_min: 0.0
  episodes_this_iter: 20
  episodes_total: 20
  experiment_id: 87214df9c01d4efeae8edd4d656a6ca4
  experiment_tag: '0'
  hostname: 2ebf5ae102f8
  info:
    grad_time_ms: 1005.051
    learner:
      learned:
        cur_kl_coeff: 0.20000000298023224
        cur_lr: 4.999999873689376e-05
        entropy: 1.0945309400558472
        entropy_coeff: 0.0
        kl: 0.004110474139451981
        policy_loss: -0.0945899486541748
        total_loss: 2.941073417663574
        vf_explained_var: 0.00013327598571777344
        vf_loss: 3.034841299057007
      learned_2:
        cur_kl_coeff: 0.20000000298023224
        cur_lr: 4.999999873689376e-05
        entropy: 1.0941331386566162
        entropy_coeff: 0.0
        kl: 0.004472262226045132
        policy_loss: -0.0190987978130579
        total_loss: 3.0051088333129883
        vf_explained_var: 0.008207857608795166
        vf_loss: 3.023313045501709
    load_time_ms: 179.466
    num_steps_sampled: 200
    num_steps_trained: 128
    sample_time_ms: 343.341
    update_time_ms: 2861.349
  iterations_since_restore: 1
  node_ip: 172.28.0.2
  num_healthy_workers: 1
  off_policy_estimator: {}
  perf:
    cpu_util_percent: 85.65
    ram_util_percent: 16.225
  pid: 1541
  policy_reward_max:
    learned: 6.0
    learned_2: 6.0
  policy_reward_mean:
    learned: -0.15
    learned_2: 0.15
  policy_reward_min:
    learned: -6.0
    learned_2: -6.0
  sampler_perf:
    mean_env_wait_ms: 0.062040254181506584
    mean_inference_ms: 3.5300535314223347
    mean_processing_ms: 1.2217222475538068
  time_since_restore: 4.562142610549927
  time_this_iter_s: 4.562142610549927
  time_total_s: 4.562142610549927
  timestamp: 1587818724
  timesteps_since_restore: 200
  timesteps_this_iter: 200
  timesteps_total: 200
  training_iteration: 1
  trial_id: 979bff44

== Status ==
Memory usage on this node: 2.0/12.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/2 CPUs, 0/0 GPUs, 0.0/7.18 GiB heap, 0.0/2.44 GiB objects
Result logdir: /root/ray_results/PPO
Number of trials: 1 (1 TERMINATED)
Trial name  status  loc reward  total time (s)  ts  iter
PPO_RockPaperScissorsEnv_979bff44   TERMINATED      0   4.56214 200 1


== Status ==
Memory usage on this node: 1.9/12.7 GiB
Using FIFO scheduling algorithm.
Resources requested: 0/2 CPUs, 0/0 GPUs, 0.0/7.18 GiB heap, 0.0/2.44 GiB objects
Result logdir: /root/ray_results/PPO
Number of trials: 1 (1 TERMINATED)
Trial name  status  loc reward  total time (s)  ts  iter
PPO_RockPaperScissorsEnv_979bff44   TERMINATED      0   4.56214 200 1


2020-04-25 12:45:24,345 INFO tune.py:352 -- Returning an analysis object by default. You can call `analysis.trials` to retrieve a list of trials. This message will be removed in future versions of Tune.

Not only can you access the actions, you should also be able to access all of the useful pre/post batch (trajectory) information this way. Have a look at the comments in the callback function above for a list of the available keys (such as obs and rewards) that you may also find useful.
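
For example, a small helper like the one below turns the recorded actions into an empirical per-episode action distribution, which is close to what the original question was after. It is only a sketch that assumes the Discrete(3) rock/paper/scissors action space of this example; you could call it from inside on_postprocess_traj right after the existing print statements.

import numpy as np

def print_empirical_action_distribution(sample_obj):
    # Count how often each of the three moves (ROCK / PAPER / SCISSORS) was
    # played in this trajectory and normalize into an empirical distribution.
    actions = sample_obj.columns(["actions"])[0]
    counts = np.bincount(actions, minlength=3)
    print('empirical action distribution (R/P/S) = {}'.format(counts / counts.sum()))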

The complete rock_paper_scissors_multiagent.py example code that prints the above output is shown below:

#!pip install ray[rllib]==0.8.2

"""A simple multi-agent env with two agents playing rock paper scissors.
This demonstrates running the following policies in competition:
    (1) heuristic policy of repeating the same move
    (2) heuristic policy of beating the last opponent move
    (3) LSTM/feedforward PG policies
    (4) LSTM policy with custom entropy loss
"""

import argparse
import random
from gym.spaces import Discrete

from ray import tune
from ray.rllib.agents.pg.pg import PGTrainer
from ray.rllib.agents.pg.pg_tf_policy import PGTFPolicy
from ray.rllib.policy.policy import Policy
from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils import try_import_tf

tf = try_import_tf()

ROCK = 0
PAPER = 1
SCISSORS = 2

parser = argparse.ArgumentParser()
parser.add_argument("--stop", type=int, default=400000)

class RockPaperScissorsEnv(MultiAgentEnv):
    """Two-player environment for rock paper scissors.
    The observation is simply the last opponent action."""

    def __init__(self, _):
        self.action_space = Discrete(3)
        self.observation_space = Discrete(3)
        self.player1 = "player1"
        self.player2 = "player2"
        self.last_move = None
        self.num_moves = 0

    def reset(self):
        self.last_move = (0, 0)
        self.num_moves = 0
        return {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }

    def step(self, action_dict):
        move1 = action_dict[self.player1]
        move2 = action_dict[self.player2]
        self.last_move = (move1, move2)
        obs = {
            self.player1: self.last_move[1],
            self.player2: self.last_move[0],
        }

        r1, r2 = {
            (ROCK, ROCK): (0, 0),
            (ROCK, PAPER): (-1, 1),
            (ROCK, SCISSORS): (1, -1),
            (PAPER, ROCK): (1, -1),
            (PAPER, PAPER): (0, 0),
            (PAPER, SCISSORS): (-1, 1),
            (SCISSORS, ROCK): (-1, 1),
            (SCISSORS, PAPER): (1, -1),
            (SCISSORS, SCISSORS): (0, 0),
        }[move1, move2]
        rew = {
            self.player1: r1,
            self.player2: r2,
        }
        self.num_moves += 1
        done = {
            "__all__": self.num_moves >= 10,
        }

        #print('obs', obs)

        return obs, rew, done, {}

class AlwaysSameHeuristic(Policy):
    """Pick a random move and stick with it for the entire episode."""

    def get_initial_state(self):
        return [random.choice([ROCK, PAPER, SCISSORS])]

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        return list(state_batches[0]), state_batches, {}

    def learn_on_batch(self, samples):
        pass

    def get_weights(self):
        pass

    def set_weights(self, weights):
        pass

class BeatLastHeuristic(Policy):
    """Play the move that would beat the last move of the opponent."""

    def compute_actions(self,
                        obs_batch,
                        state_batches=None,
                        prev_action_batch=None,
                        prev_reward_batch=None,
                        info_batch=None,
                        episodes=None,
                        **kwargs):
        def successor(x):
            if x[ROCK] == 1:
                return PAPER
            elif x[PAPER] == 1:
                return SCISSORS
            elif x[SCISSORS] == 1:
                return ROCK

        return [successor(x) for x in obs_batch], [], {}

    def learn_on_batch(self, samples):
        pass

    def get_weights(self):
        pass

    def set_weights(self, weights):
        pass

def on_postprocess_traj(info):
    """
    arg: {"agent_id": ..., "episode": ...,
          "pre_batch": (before processing),
          "post_batch": (after processing),
          "all_pre_batches": (other agent ids),
          }

    # https://github.com/ray-project/ray/blob/ee8c9ff7320ec6a2d7d097cd5532005c6aeb216e/rllib/policy/sample_batch.py
    Keys available in sample_obj:
        t
        eps_id
        agent_index
        obs
        actions
        rewards
        prev_actions
        prev_rewards
        dones
        infos
        new_obs
        action_prob
        action_logp
        vf_preds
        behaviour_logits
        unroll_id
    """
    agt_id = info["agent_id"]
    eps_id = info["episode"].episode_id
    policy_obj = info["pre_batch"][0]
    sample_obj = info["pre_batch"][1]

    if(agt_id == 'player1'):
        print('agent_id = {}'.format(agt_id))
        print('episode = {}'.format(eps_id))

        #print("on_postprocess_traj info = {}".format(info))
        #print("on_postprocess_traj sample_obj = {}".format(sample_obj))
        print('actions = {}'.format(sample_obj.columns(["actions"])))
    return

def run_same_policy():
    """Use the same policy for both agents (trivial case)."""

    #tune.run("PG", config={"env": RockPaperScissorsEnv})
    tune.run("PPO", config={"env": RockPaperScissorsEnv})

#def run_heuristic_vs_learned(use_lstm=False, trainer="PG"):
def run_heuristic_vs_learned(use_lstm=False, trainer="PPO"):
    """Run heuristic policies vs a learned agent.
    The learned agent should eventually reach a reward of ~5 with
    use_lstm=False, and ~7 with use_lstm=True. The reason the LSTM policy
    can perform better is since it can distinguish between the always_same vs
    beat_last heuristics.
    """

    def select_policy(agent_id):
        if agent_id == "player1":
            return "learned"

        elif agent_id == "player2":
            return "learned_2"

        else:
            return random.choice(["always_same", "beat_last"])

    #args = parser.parse_args()
    tune.run(trainer,
             #stop={"timesteps_total": args.stop},
             #stop={"timesteps_total": 400000},
             stop={"timesteps_total": 3},

             config={"env": RockPaperScissorsEnv,
                     #"eager": True,
                     "gamma": 0.9,
                     "num_workers": 1,
                     "num_envs_per_worker": 4,
                     "sample_batch_size": 10,
                     "train_batch_size": 200,
                     #"multiagent": {"policies_to_train": ["learned"],
                     "multiagent": {"policies_to_train": ["learned", "learned_2"],
                                    "policies": {"always_same": (AlwaysSameHeuristic, Discrete(3), Discrete(3), {}),
                                                 #"beat_last": (BeatLastHeuristic, Discrete(3), Discrete(3), {}),
                                                 "learned": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}}),
                                                 "learned_2": (None, Discrete(3), Discrete(3), {"model": {"use_lstm": use_lstm}}),
                                                 },
                                    "policy_mapping_fn": select_policy,
                                   },
                      "callbacks": {#"on_episode_start": on_episode_start,
                                    #"on_episode_step": on_episode_step,
                                    #"on_episode_end": on_episode_end,
                                    #"on_sample_end": on_sample_end,
                                    "on_postprocess_traj": on_postprocess_traj,
                                    #"on_train_result": on_train_result,
                                    }
                    }
             )

def run_with_custom_entropy_loss():
    """Example of customizing the loss function of an existing policy.
    This performs about the same as the default loss does."""

    def entropy_policy_gradient_loss(policy, model, dist_class, train_batch):
        logits, _ = model.from_batch(train_batch)
        action_dist = dist_class(logits, model)
        return (-0.1 * action_dist.entropy() - tf.reduce_mean(
            action_dist.logp(train_batch["actions"]) *
            train_batch["advantages"]))

    EntropyPolicy = PGTFPolicy.with_updates(
        loss_fn=entropy_policy_gradient_loss)
    EntropyLossPG = PGTrainer.with_updates(
        name="EntropyPG", get_policy_class=lambda _: EntropyPolicy)
    run_heuristic_vs_learned(use_lstm=True, trainer=EntropyLossPG)

if __name__ == "__main__":
    # run_same_policy()
    # run_with_custom_entropy_loss()
    run_heuristic_vs_learned(use_lstm=False)

Upvotes: 1
