user13923578
user13923578

Reputation: 11

DQN doesn't make any progress after a little while

Here is my code, its a simple DQN that learns to play snake and i dont know why it stops learning after a little while, for example. it learns that the snake head should hit the wall, but it doesnt learn to eat the fruit, even though i give a reward for getting closer to the fruit and give a GREATER negative reward for going farther away (this is to make the snake understand that it should aim to go for the fruit). But for some reason, the score never goes beyond a 1 or a 2: """ ######################################################## #MAIN.py

    # -*- coding: utf-8 -*-
    """
    Created on Mon Aug 10 13:04:45 2020
    
    @author: Ryan
    """
    
    
    from dq_learning import Agent
    import numpy as np
    import tensorflow as tf
    import snake
    import sys
    import pygame
    import gym
    
    
    
    if __name__ == '__main__':
        observation_space = 31
        action_space = 4
        lr = 0.001 
        n_games = 50000
        steps = 1000
        #env = gym.make("LunarLander-v2")
        #observation_space = env.observation_space.shape
        #action_space = env.action_space.n
        agent = Agent(gamma=0.99, epsilon=1.0, lr=lr, 
                      input_dims=observation_space,
                      n_actions=action_space,
                      batch_size=64)
        scores = []
        eps_history = []
        r = False
          
        for i in range(n_games):    
            score = 0
            #first observation
            observation = [0 for i in range(observation_space)] 
            #observation = env.reset()
            for j in range(steps):
               # env.render()
                
                for evt in pygame.event.get():
                    if evt.type == pygame.QUIT:
                        pygame.quit()
                        sys.exit()
                        
                #actions go from 0 to n_actions - based on the model prediction or random choice
                #action space is the list of all the possible actions
                action = agent.choose_action(observation)
                #print("action: ", action)
                #env.step(action) returns -> new observation, reward, done, info
                observation_, reward, done, info = snake.step(action, 25)
                #observation_, reward, done, info = env.step(action)
                #print(observation_, reward, done, info)
                score += reward
                agent.store_transition(observation, action, reward, observation_, done)
                observation = observation_
                agent.learn()
                if done:
                    break
            print("NEXT GAME")            
            done = False  
            eps_history.append(agent.epsilon)
            scores.append(score)
            
            avg_score = np.mean(scores[-100:])
            
            print("episode: ", i, " scores %.2f" %score,
                  "average score: %.2f" %avg_score, 
                  " epsilon %.2f" %agent.epsilon)
            print("last score: ", scores[-1])
        
    #####################################
    #DQ_LEARNING.PY
    
    # -*- coding: utf-8 -*-
    """
    Created on Tue Aug  4 12:23:14 2020
    
    @author: Ryan
    """
    
    
    import numpy as np
    import tensorflow as tf
    from tensorflow import keras
     
    
    class ReplayBuffer:
        def __init__(self, max_size, input_dims):
            self.mem_size = max_size
            self.mem_cntr = 0
            """
            print("self.mem_size: ", self.mem_size)
            print("*input_dims: ", *input_dims)
            """
            self.state_memory = np.zeros((self.mem_size, input_dims), dtype=np.float32)
            self.new_state_memory = np.zeros((self.mem_size, input_dims), dtype=np.float32)
            self.action_memory = np.zeros(self.mem_size, np.int32)
            self.reward_memory = np.zeros(self.mem_size, np.float32)
            self.terminal_memory = np.zeros(self.mem_size, np.int32) #done flags
            
        def store_transitions(self, state, action, reward, state_, done):
            """print("storing transactions...")
            print("mem_cntr: ", self.mem_cntr)
            print("mem_size: ", self.mem_size)
            """
            index = self.mem_cntr % self.mem_size
            self.state_memory[index] = state
            self.new_state_memory[index] = state_
            self.reward_memory[index] = reward
            self.action_memory[index] = action
            self.terminal_memory[index] = 1 - int(done)
            self.mem_cntr += 1
        
        def sample_buffer(self, batch_size):
            #print("sampling buffer...")
            max_mem = min(self.mem_cntr, self.mem_size)
            batch = np.random.choice(max_mem, batch_size, replace=False)
            #print("batch:", batch)
            states = self.state_memory[batch]
            states_ = self.new_state_memory[batch]
            rewards = self.reward_memory[batch]
            actions = self.action_memory[batch]
            terminal = self.terminal_memory[batch]
            #print("self.action_mem: ", self.action_memory)
            #print("actions: ", actions)
            
            #print("state action rewards state_, terminal", (states, actions, rewards, states_, terminal))
            return states, actions, rewards, states_, terminal
        
    def build_dqn(lr, n_actions, input_dims, fc1_dims, fc2_dims):
        model = keras.Sequential()
        model.add(keras.layers.Dense(fc1_dims, activation='relu'))
        model.add(keras.layers.Dense(fc2_dims, activation='relu'))
        model.add(keras.layers.Dense(n_actions))
        
        opt = keras.optimizers.Adam(learning_rate=lr)     
        model.compile(optimizer=opt, loss='mean_squared_error')
        
        return model
        
    class Agent():
        def __init__(self, lr, gamma, n_actions, epsilon, batch_size, 
                     input_dims, epsilon_dec=1e-3, epsilon_end=1e-2,
                     mem_size=1e6, fname='dqn_model.h5'):
            self.action_space = [i for i in range(n_actions)]
            self.gamma = gamma
            self.epsilon = epsilon
            self.eps_min = epsilon_end
            self.eps_dec = epsilon_dec
            self.batch_size = batch_size
            self.model_file = fname
            self.memory = ReplayBuffer(int(mem_size), input_dims)
            self.q_eval = build_dqn(lr, n_actions, input_dims, 256, 256)
        def store_transition(self, state, action, reward, new_state, done):
            self.memory.store_transitions(state, action, reward, new_state, done)
        def choose_action(self, observation):
            if np.random.random() < self.epsilon:
                action = np.random.choice(self.action_space)
            else:
                state = np.array([observation])
                actions = self.q_eval.predict(state)
                action = np.argmax(actions)
            return action
        
        def learn(self):
            if self.memory.mem_cntr < self.batch_size:
                return
            states, actions, rewards, states_, dones = \
                self.memory.sample_buffer(self.batch_size)
            
            q_eval = self.q_eval.predict(states)
            q_next = self.q_eval.predict(states_)
            
            q_target = np.copy(q_eval)
            batch_index = np.arange(self.batch_size, dtype=np.int32)
            
            q_target[batch_index, actions] = rewards + \
                self.gamma * np.max(q_next, axis=1)*dones
            
            self.q_eval.train_on_batch(states, q_target)
            
            self.epsilon = self.epsilon - self.eps_dec if self.epsilon > \
                self.eps_min else self.eps_min
        def save_model(self):
            self.q_eval.save(self.model_file)
        def load_model(self):
            self.q_eval =  keras.models.load_model(self.model_file)
            
            
                
    ##########################################
   # snake.py
    
     # -*- coding: utf-8 -*-
"""
Created on Fri Sep  4 14:32:30 2020

@author: Ryan
"""


import pygame
import random
from math import sqrt
import time

class Snakehead:
    def __init__(self, posx, posy, width, height):
        self.posx = posx
        self.posy = posy
        self.width = width
        self.height = height
        self.movement = 'null'
        self.speed = 16
        self.gameover = False 
    def draw(self, Display):     #RGB #coordinates/dimentions
        pygame.draw.rect(Display, [0, 0, 0], [self.posx, self.posy, self.width, self.height])
    def read_input(self, key):
        if key == 0 and key != 1:
            self.movement = 'left'
        elif key == 1 and key != 0:
            self.movement = 'right'
        elif key == 2 and key != 3:
            self.movement = 'up'
        elif key == 3 and key != 2:
            self.movement = 'down'
        print(self.movement)
    def get_pos(self):
        return self.posx, self.posy
    def get_movement(self):
        return self.movement
    def restart(self, ScreenW, ScreenH):
        self.posx = ScreenW / 2 - 16/2
        self.posy = ScreenH / 2 - 16/2
    def move(self, SW, SH):

        if self.movement == 'right':
            self.posx += self.speed # self.posx = self.posx + self.speed
        elif self.movement == 'left':
            self.posx -= self.speed # self.posx = self.posx - self.speed
        elif self.movement == 'up':
            self.posy -= self.speed # self.posy = self.posy - self.speed
        elif self.movement == 'down':
            self.posy += self.speed # self.posy = self.posy + self.speed


class Food:
    def __init__(self, posx, posy, width, height):
        self.posx = posx
        self.posy = posy
        self.width = width
        self.height = height
        self.red = random.randint(155, 255)
    def draw(self, Display):
        pygame.draw.rect(Display, [self.red, 0, 0], [self.posx, self.posy, self.width, self.height])
    def get_pos(self):
        return self.posx, self.posy
    def respawn(self, ScreenW, ScreenH):
        self.posx = random.randint(1, (ScreenW - 16)/16) * 16 
        self.posy = random.randint(1, (ScreenH - 16)/16) * 16 
        self.red = random.randint(155, 255)
    

class Tail:
    def __init__(self, posx, posy, width, height):
        self.width = width
        self.height = height
        self.posx = posx
        self.posy = posy
        self.RGB = [random.randint(0, 255) for i in range(3)]
        
    def draw(self, Diplay):
        pygame.draw.rect(Diplay, self.RGB, [self.posx, self.posy, 16, 16])

    def move(self, px, py):
        self.posx = px
        self.posy = py

    def get_pos(self):
        return self.posx, self.posy


ScreenW = 720
ScreenH = 720

sheadX = 0
sheadY = 0

fX = 0
fY = 0

counter = 0




pygame.init()
pygame.display.set_caption("Snake Game")

Display = pygame.display.set_mode([ScreenW, ScreenH])
Display.fill([255, 255, 255]) #RGB white

black = [0, 0, 0]
font = pygame.font.SysFont(None, 30)
score = font.render("Score: 0", True, black)

shead = Snakehead(ScreenW / 2 - 16/2, ScreenH / 2 - 16/2, 16, 16)
f = Food(random.randint(0, (ScreenW - 16)/16) * 16 - 8, random.randint(0, (ScreenH - 16)/16) * 16, 16, 16)
tails = []

Fps = 60
timer_clock = pygame.time.Clock()
previous_distance = 0
d = 0

def step(action, observation_space):
    global score, counter, tails, shead, gameover, previous_distance, d
    shead.gameover = False
    observation_, reward, done, info = [0 for i in range(observation_space+6)], 0, 0, 0
    Display.fill([255, 255, 255])
    shead.read_input(action)
    sheadX, sheadY = shead.get_pos()
    fX, fY = f.get_pos()
    #detect collision
    if sheadX + 16 > fX and sheadX < fX + 16:
        if sheadY + 16 > fY and sheadY < fY + 16:
            #collision
            f.respawn(ScreenW, ScreenH)
            counter += 1 # counter = counter + 1
            score = font.render("Score: " + str(counter), True, black)
            if len(tails) == 0:
                tails.append(Tail(sheadX, sheadY, 16, 16))
            #tails.append(tail.Tail(sheadX, sheadY, 16, 16, shead.get_movement()))
            else:
                tX, tY = tails[-1].get_pos()
                tails.append(Tail(tX, tY, 16, 16))
            reward = 100
            print(tails)

    for i in range(len(tails)):
        try:
            tX, tY = tails[i].get_pos()
            #print("tx: ", tX, " ty: ", tY)
            sX, sY = shead.get_pos()
            #print("Sx: ", sX, " sy: ", sY)
            if i != 0 and i != 1:
                #print("more than 2 tails")
                if tX == sX and tY == sY:
                    print("collision")
                    #collision
                    shead.restart(ScreenW, ScreenH)
                    tails.clear()
                    counter = 0
                    Display.blit(score, (10, 10))
                    pygame.display.flip()
                    pygame.display.update()
                    reward = -300
                    shead.gameover = True
                    print("lost-3")
        except:
            shead.restart(ScreenW, ScreenH)
            tails.clear()
            counter = 0
            reward = -300
            shead.gameover = True
            print("lost-0")

        
    sX, sY = shead.get_pos()
    if sX < 0 or sX + 16 > ScreenW:
            shead.restart(1280, 720)
            counter = 0
            Display.blit(score, (10, 10))
            pygame.display.flip()
            pygame.display.update()
            tails.clear()
            print("lost-1")
            reward = -200
            shead.gameover = True
            #restart
    elif sY < 0 or sY + 16 > ScreenH:
        shead.restart(1280, 720)
        counter = 0
        Display.blit(score, (10, 10))
        pygame.display.flip()
        pygame.display.update()
        tails.clear()
        reward = -200
        shead.gameover = True
        print("lost-2")
            #restart

    for i in range(1, len(tails)):
        tX, tY = tails[len(tails) - i - 1].get_pos() # y = b - x
        tails[len(tails) - i].move(tX, tY) 
    if len(tails) > 0:
        tX, tY = shead.get_pos()
        tails[0].move(tX, tY)
    shead.move(ScreenW, ScreenH)
    shead.draw(Display)
    Display.blit(score, (10, 10))
    for tail in tails:
        tail.draw(Display)
    f.draw(Display)
    pygame.display.flip()
    pygame.display.update()
    timer_clock.tick(Fps)
    #observation, done
    done = shead.gameover
    hx, hy = shead.get_pos()
    hx /= ScreenW
    hy /= ScreenH

    fx, fy = f.get_pos()
    fx /= ScreenW
    fy /= ScreenH
    

    observation_[0] = abs(hx - fx)
    observation_[1] = abs(hy - fy)
    previous_distance = d
    d = sqrt((fx - hx)**2 + (fy - hy)**2)
    #print("distance: ", d)
    observation_[2] = d
    observation_[3] = 0
    #print("observation_[4]: ", observation_[4])
    observation_[4] = hx
    observation_[5] = hy
    c = 6
    xlist = []
    ylist = []
    for t in tails:         
        tx, ty = t.get_pos()
        tx /= 16
        ty /= 16
        xlist.append(tx)
        ylist.append(ty)
    l = int(sqrt(observation_space))
    startX, startY = shead.get_pos()
    startX /= 16
    startY /= 16
    m = (l-1)/2
    #print("xlist:" , xlist)
    #print("ylist:", ylist)
    #print("startX: ", startX)
    #print("startY: ", startY)
    #print("m: ", m)
    #print("l: ", l)
    for x in range(l):
        for y in range(l):
            found = False
            #print("position: (", int(startX) - m + x, ",", int(startY) - m + y, ")")
            for i in range(len(xlist)):
                """print("i:", i)
                print("pos: ", startX - m + x)
                print("j: ", j)
                print("pos: ", startY - m + y)
                """
                #print("current iteration: (", int(xlist[i]), ",", int(ylist[i]), ")")
                if int(xlist[i]) == int(startX) - m + x and int(ylist[i]) == int(startY) - m + y:
                    #print("found a match")
                    observation_[c] = 1
                    #print("c is: ", c)
                    #print("observation_[c] is: ", observation_[c])
                    found = True
                    break
            if not found:
                #print("set to 0")
                observation_[c] = 0
            #print("increasing c...")
            c += 1
            
    print("reward: ", reward)
    print("c_reward: ", counter*10)     
    d_reward = 10 if d < previous_distance else - 100 
    print("d_reward: ", d_reward)       
    print(observation_, reward + d_reward + counter*10, done, 0)
    
    return observation_, reward, done, 0



    

Upvotes: 1

Views: 456

Answers (1)

Jose
Jose

Reputation: 86

The reward function looks fine to me.

However, you say "I give a reward for getting closer to the fruit and give a GREATER negative reward for going farther away" but in the code it does not look like you use d_reward:

print("reward: ", reward)
print("c_reward: ", counter*10)     
d_reward = 10 if d < previous_distance else - 100 
print("d_reward: ", d_reward)       
print(observation_, reward + d_reward + counter*10, done, 0)

return observation_, reward, done, 0

This is fine, as d_reward is definitely not necessary. Only giving positive reward for eating the apple, negative for dying and 0 otherwise is enough.

I suspect that the issue is in your state representation. Only by looking at your state, it is impossible for your agent to know which direction it should go, as the information of the apple position relative to the head is given with absolute values.

As an example lets say that your board is as follows:

[food,  head,  empty]

Your observation would be:

[1, 0, 1, 0, 1, 0]

But if your board is:

[empty, head,  food]

The observation is the same:

[1, 0, 1, 0, 1, 0]

This is a problem. With a given input, the same action could be good or bad, whithout any way of knowing it. This makes learning impossible. In our example, for the input [1, 0, 1, 0, 1, 0], our network could move towards (or away) from both: left and right , never converging in any action.

This is because in your training data you will have examples of that input where moving to the left is good, others where it is neutral, others where it is bad, and examples of that input where right is good, neutral, bad etc.

I would recommend to encode more information in your state (or observation). I suggest something like this (which I took from a project of mine, you'll need to adapt it):

def get_state(self):
    head = self.snake[0]

    danger_top = head.y == self.board_dim.y - 1 or Point(head.x, head.y + 1) in self.snake
    danger_bot = head.y == 0 or Point(head.x, head.y - 1) in self.snake
    danger_right = head.x == self.board_dim.x - 1 or Point(head.x + 1, head.y) in self.snake
    danger_left = head.x == 0 or Point(head.x - 1, head.y) in self.snake

    apple_top = head.y < self.apple.y
    apple_bot = head.y > self.apple.y
    apple_right = head.x < self.apple.x
    apple_left = head.x > self.apple.x

    return np.array([
        danger_top,
        danger_bot,
        danger_right,
        danger_left,
        apple_top,
        apple_bot,
        apple_right,
        apple_left], dtype=int)

Please, let me know if I did miss some part of your code or if you have any doubt. Thank you in advance.

Upvotes: 2

Related Questions