Eric Monlye

Reputation: 121

DQN learns extremely slowly on cliff walking

0. Some background

I was trying to train a DQN agent for a task with a very large observation space (a 3D task with 100 * 100 * 30 * 36 * 18 states), and it was not converging. After debugging for quite a long time, I noticed that the agent was still walking randomly after thousands of episodes. So I tried a simple 2D environment to see if anything was wrong with my algorithm; here are my settings and results.

1. The task settings

I was trying to train a vanilla DQN agent on the cliff walking task. The environment is simply a 3 * 3 grid, with DEAD marking the cliff, DEST the destination, and FROM the starting point.

# configs
episodes = 1000
steps = 200

FREE = 0
FROM = 1
DEAD = 2
DEST = 3

cliff_env = [
    [FREE, FREE, FREE],
    [FREE, FREE, FREE],
    [FROM, DEAD, DEST],
]

env = dict(
    type="CliffWalkDQNEnv",
    env=cliff_env,
    device="cuda",
)
agent = dict(
    type="CliffWalkDQN",
    device="cuda",
    gamma=0.98,
    qnet_cfg=dict(
        device="cuda",
        state_dim=9,
        hidden_dim=128,
        action_dim=4,
    ),
    lr=1e-3,
    explorer_cfg=dict(
        type="BaseEpsilonGreedyExplorer",
        samples=4,
        epsilon=0.999,
        epsilon_decay=0.994,
        epsilon_min=0.05,
    ),
)

replay_buffer = dict(
    type="CliffWalkReplayBuffer",
    device="cuda",
    batch_size=100,
    capacity=5000,
    activate_size=500,
)
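
For reference, a rough check of how fast this epsilon schedule decays (epsilon is updated once per decide() call, i.e. once per training step, see train_take_action below):

# Back-of-the-envelope check of the epsilon schedule:
# epsilon_n = 0.999 * 0.994 ** n drops below epsilon_min = 0.05 when
# n >= ln(0.05 / 0.999) / ln(0.994).
import math

steps_to_min = math.log(0.05 / 0.999) / math.log(0.994)
print(round(steps_to_min))  # ~498 steps, i.e. only a few 200-step episodes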

I encode the states (observations) as one-hot vectors for the qnet input. The qnet is a 3-layer fully connected network, and the agent updates the target_qnet parameters every 10 episodes.

import torch
import torch.nn.functional as F
from torch import Tensor
from torch.nn import Linear, Module


class QNet(Module):
    def __init__(self, device, state_dim, hidden_dim, action_dim):
        super().__init__()
        self.device = torch.device(device)

        self.linear = Linear(state_dim, hidden_dim, device=self.device)
        self.mid = Linear(hidden_dim, hidden_dim, device=self.device)
        self.out = Linear(hidden_dim, action_dim, device=self.device)

    def forward(self, input: Tensor):
        # Observations arrive as flat cell indices; one-hot encode them for the MLP.
        obs_indices = input.type(torch.int64)
        x = F.one_hot(
            obs_indices, num_classes=self.linear.in_features
        ).type(torch.float).view(-1, self.linear.in_features)

        x = F.relu(self.linear(x))
        x = F.relu(self.mid(x))

        out = self.out(x)

        return out


@EXPLORERS.register_module()
class BaseEpsilonGreedyExplorer(BaseExplorer):
    def __init__(
        self,
        samples: int = 2,
        epsilon: float = 0.9,
        epsilon_decay: float = 1.0,
        epsilon_min: float = 0.05,
        seed: Optional[int] = None,
    ):
        super().__init__(seed)
        self.epsilon_max = epsilon
        self.epsilon = epsilon
        self.epsilon_decay = epsilon_decay
        self.epsilon_min = epsilon_min

        self.samples = samples if isinstance(samples, int) else np.array(samples)

    def decide(self, observation: array):
        self._epsilon_update()
        return self._random.uniform(0, 1) < self.epsilon

    def act(self):
        return self._random.choice(self.samples)
    
    def _epsilon_update(self):
        self.epsilon = max(self.epsilon_min, self.epsilon * self.epsilon_decay)

    def reset_epsilon(self):
        self.epsilon = self.epsilon_max


@AGENTS.register_module()
class CliffWalkDQN(BaseAgent):
    FREE = 0
    FROM = 1
    DEAD = 2
    DEST = 3

    def __init__(
        self,
        device,
        gamma,
        qnet_cfg,
        lr,
        explorer_cfg,
        test_mode: bool = False,
    ):
        super().__init__(device, gamma)

        self.test_mode = test_mode
        self.target_qnet: Module = QNet(**qnet_cfg).to(self.device)

        if not self.test_mode:
            self.explorer: BaseExplorer = EXPLORERS.build(explorer_cfg)
            self.qnet: Module = QNet(**qnet_cfg).to(self.device)
            self.target_qnet.load_state_dict(self.qnet.state_dict())

            self.lr = lr
            self.opt = torch.optim.Adam(self.qnet.parameters(), lr=self.lr)

            self.start_episode = 0

    def take_action(self, observation, **kwargs):
        if self.test_mode:
            return self.test_take_action(observation, **kwargs)
        else:
            return self.train_take_action(observation, **kwargs)

    def train_take_action(self, observation, **kwargs):
        if self.explorer.decide(observation):
            return self.explorer.act()
        else:
            with torch.no_grad():
                obs: Tensor = torch.tensor(observation, dtype=torch.float, device=self.device).view(-1, 1)

                action_dist: Tensor = self.qnet(obs).squeeze(0)

                action_index: int = action_dist.argmax().type(torch.int64).item()

            return action_index

    def test_take_action(self, observation, **kwargs):
        with torch.no_grad():
            obs: Tensor = torch.tensor(observation, dtype=torch.float, device=self.device).view(-1, 1)

            action_dist: Tensor = self.qnet(obs).squeeze(0)

            action_index: int = action_dist.argmax().type(torch.int64).item()

        return action_index

    def update(self, transitions: dict, logger: LoggerHook):
        cur_observations: Tensor = torch.tensor(
            transitions["cur_observation"], dtype=torch.float, device=self.device
        ).view(-1, 1)
        cur_actions: Tensor = torch.tensor(
            transitions["cur_action"], dtype=torch.int64, device=self.device
        ).view(-1, 1)
        next_observations: Tensor = torch.tensor(
            transitions["next_observation"], dtype=torch.float, device=self.device
        ).view(-1, 1)
        rewards: Tensor = torch.tensor(
            transitions["reward"], dtype=torch.float, device=self.device
        ).view(-1, 1)
        terminated: Tensor = torch.tensor(
            transitions["terminated"], dtype=torch.float, device=self.device
        ).view(-1, 1)

        values_predict: Tensor = torch.gather(
            self.qnet(cur_observations),
            dim=1,
            index=cur_actions,
        )
        td_target = rewards + self.gamma * torch.max(
            self.target_qnet(next_observations),
            dim=1,
        )[0].view(-1, 1) * (1 - terminated)

        loss = torch.mean(F.mse_loss(values_predict, td_target))

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()

        if (logger.cur_episode_index + 1) % 10 == 0:
            self.target_qnet.load_state_dict(self.qnet.state_dict())

        return loss.item()
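
As a quick sanity check of the one-hot encoding, here is a standalone snippet (not part of the training code; it assumes the 3 * 3 grid, so state_dim=9 and observations are flat cell indices):

# Hypothetical shape check: feed a small batch of flat cell indices through QNet.
import torch

qnet = QNet(device="cpu", state_dim=9, hidden_dim=128, action_dim=4)
obs = torch.tensor([[6.0], [0.0]])  # two observations, shape (2, 1)
q_values = qnet(obs)                # one-hot -> (2, 9) -> Q-values, shape (2, 4)
print(q_values.shape)               # torch.Size([2, 4])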

I wrote the environment myself. Every step gives a -1 reward, and falling into the cliff gives a -100 reward. The episode ends when 200 steps have been taken, the agent falls into the cliff, or it reaches the DEST. During both training and evaluation, the agent always starts from the FROM point.

@ENVIRONMENTS.register_module()
class CliffWalkDQNEnv:
    INT = torch.int64
    FLOAT = torch.float

    FREE = 0
    FROM = 1
    DEAD = 2
    DEST = 3

    def __init__(
        self,
        env: List[int],
        device: Optional[str],
        save_path: str = None,
    ):
        self.device = torch.device(device)
        self.env: array = np.array(env, dtype=np.int64)
        assert len(np.transpose(np.nonzero(self.env == self.FROM))) == 1, (
            "Multiple start points found."
        )
        self.env_count: array = np.zeros_like(self.env)

    @property
    def env_shape(self):
        return list(self.env.shape)
    
    @property
    def start_point_index(self):
        coord = np.transpose(np.nonzero(self.env == self.FROM))[0]

        obs_index = _coord2obs_index(coord, self.env_shape)

        return obs_index

    def reset(self) -> array:
        return self.start_point_index

    def step(
        self,
        observation: array,
        action: array,
    ) -> Dict[str, array | int | bool]:
        cur_obs_coord = _obs_index2_coord(observation, self.env_shape)
        self.env_count[cur_obs_coord[0], cur_obs_coord[1]] += 1
        movement = _action_index2movement(action, 2)  # [-1, 0], [+1, 0], [0, -1], [0, +1]

        env_shape: array = np.array(self.env_shape, dtype=np.int64)
        upper_bound: array = env_shape - 1
        lower_bound: array = np.zeros_like(upper_bound, dtype=np.int64)

        next_obs_coord: array = np.clip(
            cur_obs_coord + movement, lower_bound, upper_bound, dtype=np.int64
        )
        next_pos_state = self.env[next_obs_coord[0], next_obs_coord[1]]
        if next_pos_state == self.DEST or next_pos_state == self.DEAD:
            self.env_count[next_obs_coord[0], next_obs_coord[1]] += 1

        if next_pos_state == self.DEAD:
            reward = -100
        # elif next_pos_state == self.DEST:
        #     reward = 100
        else:
            reward = -1

        transition = dict(
            cur_observation=observation,
            cur_action=action,
            next_observation=_coord2obs_index(next_obs_coord, self.env_shape),
            reward=reward,
            terminated=(
                next_pos_state == self.DEST
                or next_pos_state == self.DEAD
            ),
        )

        return transition
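
The coordinate helpers _coord2obs_index, _obs_index2_coord, and _action_index2movement are omitted above. A minimal sketch of plausible implementations (my assumption: row-major flattening, and the movement order given in the step() comment):

import numpy as np

def _coord2obs_index(coord, env_shape):
    # Flatten a (row, col) coordinate into a single observation index (row-major).
    return int(coord[0] * env_shape[1] + coord[1])

def _obs_index2_coord(obs_index, env_shape):
    # Inverse of _coord2obs_index: recover the (row, col) coordinate.
    obs_index = int(obs_index)
    return np.array([obs_index // env_shape[1], obs_index % env_shape[1]], dtype=np.int64)

def _action_index2movement(action_index, ndim):
    # For an ndim-dimensional grid there are 2 * ndim moves: -1/+1 along each axis.
    # With ndim=2 this yields [-1, 0], [+1, 0], [0, -1], [0, +1].
    movements = []
    for axis in range(ndim):
        for delta in (-1, 1):
            move = np.zeros(ndim, dtype=np.int64)
            move[axis] = delta
            movements.append(move)
    return movements[int(action_index)]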

The training process is done by a runner.


@RUNNERS.register_module()
class CliffWalkDQNRunner:
    def __init__(self, cfg: Config):
        self.work_dir: str = cfg.work_dir
        self.logger: LoggerHook = HOOKS.build(cfg.logger_cfg)

        self.env = ENVIRONMENTS.build(cfg.env)
        self.agent = AGENTS.build(cfg.agent)
        self.test_mode: bool = self.agent.test_mode
        self.start_episode: int = 0
        self.eval_interval: int = cfg.eval_interval

        # max step number for an episode, exceeding makes truncated True
        self.steps: int = cfg.steps

        if not self.test_mode:
            self.replay_buffer = REPLAY_BUFFERS.build(cfg.replay_buffer)
            if not self.replay_buffer.is_active():
                self.generate_replay_buffer()

            self.save_checkpoint_interval: int = cfg.save_checkpoint_interval

            # max episode number
            self.episodes: int = cfg.episodes

            self.start_episode = self.agent.start_episode
            if "resume_from" in cfg.agent.keys() and cfg.agent.resume_from is not None:
                self.logger.set_cur_episode_index(self.start_episode - 1)

    def generate_replay_buffer(self):
        # generates replay buffer with epsilon greedy
        # until the buffer gets enough experiences
        pass

    def train(self):
        cur_observation = None

        for episode in range(self.start_episode, self.episodes):
            cur_observation = self.env.reset()

            episode_return: Dict[str, float] = dict(
                loss=0,
                reward=0,
                coverage=0,
            )
            terminated = False

            for step in range(self.steps):
                cur_action = self.agent.take_action(
                    cur_observation,
                    episode_index=episode,
                    step_index=step,
                    logger=self.logger,
                    save_dir=self.logger.save_dir,
                )

                cur_transition = self.env.step(
                    cur_observation, cur_action
                )

                cur_observation = cur_transition["next_observation"]
                terminated = cur_transition["terminated"]
                truncated = step == self.steps - 1

                episode_return["reward"] += cur_transition["reward"]

                self.replay_buffer.add(cur_transition)

                if terminated or truncated:
                    break

            episode_return["loss"] = self.agent.update(
                self.replay_buffer.sample(), self.logger
            )
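
The replay buffer implementation is also omitted. Judging from update(), sample() is expected to return a batch collated into a dict of per-key sequences of length batch_size; roughly something like this (illustrative values only, using the row-major indexing assumed above):

# Sketch of the batch format that update() consumes; only three of the
# batch_size=100 transitions are shown.
batch = dict(
    cur_observation=[6, 3, 6],       # flat cell indices
    cur_action=[0, 3, 3],            # action indices in [0, 4)
    next_observation=[3, 4, 7],
    reward=[-1, -1, -100],           # -100 when the move lands on DEAD
    terminated=[False, False, True],
)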

2. Training results

2.1. 3 * 3 world

After training for over 300 episodes, the agent finally reached the DEST for the first time (see the reward and log10_loss graph). The training process is quite slow, and the agent seemed to learn nothing before the 300th episode, judging by the printed max probability strategy:

==== episode 270 ====
==== max probability strategy ====
^  <  >  
^  ^  >  
^  x  ^  
==== counts of times that the agent visits corresponding position ====
[[29830  7597  9422]
 [ 2619   183   734]
 [ 5271    62    17]]

As you can see, after 270 episodes the agent has only reached the DEST 17 times. These experiences are rarely sampled from the buffer, and I suspect the strategies at those seldom-reached positions are still essentially random.
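
To put a rough number on that (an upper-bound estimate, assuming all 17 successful transitions are still in the 5000-capacity buffer):

# Expected number of DEST-reaching transitions per sampled batch:
expected_per_batch = 100 * 17 / 5000  # batch_size * successes / capacity
print(expected_per_batch)             # 0.34, so most updates see none at all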

The final result is as expected, and the agent finds the optimal path:

==== episode 1000 ====
v  >  v  
>  >  v  
^  x  ^  
========
[[33401  8040 13459]
 [ 8381  1320  3219]
 [15419   154   723]]

2.2. 4 * 6 world

I also tested a larger 4 * 6 environment:

    [FREE, FREE, FREE, FREE, FREE, FREE],
    [FREE, FREE, FREE, FREE, FREE, FREE],
    [FREE, FREE, FREE, FREE, FREE, FREE],
    [FROM, FREE, DEAD, DEAD, FREE, DEST],

To accelerate exploration, this time I rewrote the env.reset() method to allow a random start point during training. During evaluation, the agent still always starts from the FROM position:

# CliffWalkDQNEnv
    @property
    def start_point_index(self):
        coord = np.transpose(np.nonzero(self.env == self.FROM))[0]

        obs_index = _coord2obs_index(coord, self.env_shape)

        return obs_index

    def reset(self, test_mode: bool = False) -> array:
        if test_mode:
            return self.start_point_index
        
        coords = np.transpose(np.nonzero(self.env != self.DEAD))

        obs_index = _coord2obs_index(
            coords[np.random.choice(len(coords))], self.env_shape
        )

        return obs_index

I'm not sure if this trick actually contributes to any acceleration.

==== episode 400 ====
v  v  v  >  v  v  
^  ^  ^  ^  v  v  
^  ^  ^  ^  >  v  
^  ^  x  x  ^  >  
========
[[21054  6877  8290  5902  6691  9866]
 [ 4390  1032   441   177   370  1578]
 [ 1282    95   103    88   131  1935]
 [ 4047  1798    22     9   272   115]]

During evaluation, the agent reached the DEST for the first time at episode 515. After that, it found the safer path and gradually converged to the optimal one.

==== episode 665 ====
>  >  >  >  v  <  
>  >  >  ^  >  <  
>  ^  ^  ^  >  v  
^  ^  x  x  ^  >  
========
[[22312  7619  8602  6605  7017  9961]
 [ 4717  1435   593   355   995  1803]
 [ 2833   411   156   105   389  2197]
 [ 8382  1857    26     9   284   396]]

==== episode 705 ====
>  >  ^  >  <  ^  
>  >  >  >  v  v  
^  ^  ^  ^  >  v  
^  ^  x  x  ^  >  
========
[[22313  7628  8669  8993  7673  9974]
 [ 4721  1446   612   408  1409  2198]
 [ 2838   419   160   111   408  2226]
 [ 8394  1863    27     9   284   429]]

(Reward and log10_loss graph)

3. My questions

  1. During a long period at the beginning of training, the agent barely reached the DEST, even in a 3 * 3 world (and this observation space is REALLY small). Is this a usual phenomenon for DQN?
  2. If I try DQN on a really large observation space, how can I speed up this seemingly random initial period?

Any insights and suggestions are appreciated!

Upvotes: 1

Views: 28

Answers (0)
