Reputation: 21
I am trying to implement DQN in OpenAI Gym's "LunarLander" environment.
It shows no sign of converging after 3000 training episodes. (For comparison, a very simple policy gradient method converges after 2000 episodes.)
I have gone through my code several times but can't find what is wrong. I hope someone here can point out where the problem is. Below is my code:
I use a simple fully-connected network:
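    class Net(nn.Module):
        def __init__(self) -> None:
            super().__init__()  # must run before submodules are assigned
            # 8 state features in, one Q-value per action (4 actions) out
            self.main = nn.Sequential(
                nn.Linear(8, 16),
                nn.ReLU(),  # assumed: no activation is visible in the post; without one the layers collapse to a single linear map
                nn.Linear(16, 16),
                nn.ReLU(),  # assumed, as above
                nn.Linear(16, 4)
            )

        def forward(self, state):
            return self.main(state)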
I use epsilon-greedy action selection, and epsilon (starting from 0.5) decays exponentially over time:
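    def sample_action(self, state):
        # epsilon shrinks by 1% on every call
        self.epsilon = self.epsilon * 0.99
        # despite the name, these are the network's Q-value estimates per action
        action_probs = self.network_train(state)
        random_number = random.random()
        if random_number < (1 - self.epsilon):
            # exploit: pick the action with the highest estimated value
            action = torch.argmax(action_probs, dim=-1).item()
        else:
            # explore: pick one of the 4 actions uniformly at random
            action = random.choice([0, 1, 2, 3])
        return action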
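To make the speed of this schedule concrete, here is a minimal sketch of the same multiplicative decay in plain Python. The helper name `epsilon_after` and the optional `floor` parameter are hypothetical additions for illustration and are not part of the agent code above.

```python
def epsilon_after(n_calls, start=0.5, decay=0.99, floor=0.0):
    """Epsilon after n_calls multiplicative decays, never going below floor."""
    return max(floor, start * decay ** n_calls)


# Compare the plain schedule with one that keeps a small floor (assumed value 0.05).
for n in (0, 100, 500, 1000):
    print(n, epsilon_after(n), epsilon_after(n, floor=0.05))
```

With a factor of 0.99 applied on every call, epsilon falls below 0.2 within about 100 action selections, so a floor or a slower decay may be worth trying if exploration seems to end too early.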
When training, I use a replay buffer, a batch size of 64, and gradient clipping:
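    def learn(self):
        if len(self.buffer) >= BATCH_SIZE:
            self.learn_counter += 1
            # sample a minibatch and regroup it field by field
            transitions = self.buffer.sample(BATCH_SIZE)
            batch = Transition(*zip(*transitions))
            state = torch.from_numpy(np.concatenate(batch.state)).reshape(-1, 8)
            action = torch.tensor(batch.action).reshape(-1, 1)
            reward = torch.tensor(batch.reward).reshape(-1, 1)
            # Q(s, a) for the actions that were actually taken
            state_value = self.network_train(state).gather(1, action)
            next_state = torch.from_numpy(np.concatenate(batch.next_state)).reshape(-1, 8)
            # max over a' of Q_target(s', a'), detached so no gradient reaches the target network
            next_state_value = self.network_target(next_state).max(1)[0].reshape(-1, 1).detach()
            loss = F.mse_loss(state_value.float(), (self.DISCOUNT_FACTOR*next_state_value + reward).float())
            # optimizer step; `self.optimizer` is an assumed attribute name, the post does not show it
            self.optimizer.zero_grad()
            loss.backward()
            # gradient clipping: clamp every gradient element to [-1, 1]
            for param in self.network_train.parameters():
                param.grad.data.clamp_(-1, 1)
            self.optimizer.step()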
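One thing the target above does not do is treat terminal transitions differently: in standard DQN the target for a step that ended the episode is just the reward, with no bootstrapped term. A minimal sketch in plain Python, assuming each transition also stored a `done` flag (the `Transition` tuple in this post does not show one, and `td_targets` is a hypothetical helper):

```python
def td_targets(rewards, next_q_max, dones, gamma=0.99):
    """One-step TD targets: r + gamma * max_a' Q_target(s', a'), or just r if terminal."""
    targets = []
    for r, q_next, done in zip(rewards, next_q_max, dones):
        # No bootstrapping from the next state once the episode has ended.
        bootstrap = 0.0 if done else gamma * q_next
        targets.append(r + bootstrap)
    return targets


# Second transition is terminal, so its target is the raw reward.
print(td_targets([1.0, -100.0], [10.0, 5.0], [False, True], gamma=0.5))  # [6.0, -100.0]
```

In a tensor version this would roughly correspond to multiplying `next_state_value` by `(1 - done)` before adding `reward`.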
I also use a target network; its parameters are updated every 100 timesteps:
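    def update_network_target(self):
        # hard update: copy the training network's weights every 100 learn steps
        if (self.learn_counter % 100) == 0:
            self.network_target.load_state_dict(self.network_train.state_dict())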
BTW, I use the Adam optimizer with a learning rate of 1e-3.
Upvotes: 0
Views: 442
Reputation: 21
Solved. Apparently the target network was being updated too frequently. I changed the update to once every 10 episodes, and that fixed the problem.
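For anyone landing here, a rough sketch of what counting syncs per episode instead of per learning step can look like. `TargetSync` and `episode_finished` are hypothetical names; wherever the sketch reports that a sync is due, the agent would copy the training network's weights into the target network.

```python
class TargetSync:
    """Signals a target-network sync once every `period` finished episodes."""

    def __init__(self, period=10):
        self.period = period
        self.episodes_done = 0

    def episode_finished(self):
        """Call at the end of each episode; returns True when a sync is due."""
        self.episodes_done += 1
        return self.episodes_done % self.period == 0


sync = TargetSync(period=10)
# Episodes (1-based) after which the target network would be refreshed.
due = [ep for ep in range(1, 31) if sync.episode_finished()]
print(due)  # [10, 20, 30]
```

An episode spans many environment steps, so a sync every 10 episodes is a much longer interval than one every 100 steps, which gives the bootstrapped targets more time to stay fixed between updates.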
Upvotes: 0