Reputation: 7019
I'm pretty new to deep learning and neural networks, and I'm trying to implement an agent that can play my simple game.
The goal is to get the highest possible score (the sum of the cells visited) while reaching the destination (the orange cell) within the available steps (always greater than or equal to the distance from the player to the finish cell).
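Roughly, the objective can be sketched like this (a placeholder illustration, not the actual game code; field, path, finish and max_steps are just illustrative names):

def evaluate(field, path, finish, max_steps):
    # the score is the sum of the values of the visited cells
    score = sum(field[r][c] for r, c in path)
    # the path has to end on the finish cell within the allowed number of steps
    reached = path[-1] == finish and len(path) - 1 <= max_steps
    return score, reached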
The model for my network is really simple (I'm using tflearn):
import tflearn
from tflearn.layers.core import input_data, fully_connected
from tflearn.layers.estimator import regression

network = input_data(shape=[None, 13, 1], name='input')
network = fully_connected(
    network,
    13**2,
    activation='relu'
)
network = fully_connected(network, 1, activation='linear')
network = regression(
    network,
    optimizer='adam',
    learning_rate=self.lr,  # learning rate configured on the agent
    loss='mean_square',
    name='target',
)
model = tflearn.DNN(network, tensorboard_dir='log')
where 13 is the number of features I'm able to extract from the game state. But the resulting model shows really bad behaviour when playing:
[default] INFO:End the game with a score: 36
[default] INFO:Path: up,up,up,up,up,up,up,up,up,up,up,up,up,up,up
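For reference, a single-output value model like this is typically used at play time by scoring every candidate move and picking the best one. A rough sketch (legal_moves, apply_move and extract_features are placeholders for whatever the game provides):

import numpy as np

def choose_move(model, state):
    best_move, best_value = None, -np.inf
    for move in legal_moves(state):                            # placeholder game helper
        features = extract_features(apply_move(state, move))   # placeholder, 13 features
        value = model.predict(np.reshape(features, (-1, 13, 1)))[0][0]
        if value > best_value:
            best_move, best_value = move, value
    return best_move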
So I want to figure out which important parts I have missed, and I have some open questions to clarify. For reference, here is the training output:
Training Step: 3480 | total loss: 0.11609 | time: 4.922s
| Adam | epoch: 001 | loss: 0.11609 -- iter: 222665/222665
I understand that this is a somewhat open-ended question and it might be inappropriate to post it here, so I'd appreciate any kind of guidance or general comments.
Upvotes: 0
Views: 2733
Reputation: 2975
Traditionally, reinforcement learning was limited to solving discrete-state, discrete-action problems, because continuous problems run into the "curse of dimensionality". For example, let's say a robot arm can move between 0 and 90 degrees. That would require an action for angle = 0, 0.00001, 0.00002, ..., which is infeasible for traditional tabular RL.
To solve this problem, we had to teach RL that 0.00001 and 0.00002 are more or less the same. To achieve this, we use function approximation, such as neural networks. These approximators aim to approximate the Q-matrix of tabular RL and to capture the policy (i.e., the choices of the robot). However, even today, non-linear function approximators are known to be extremely difficult to train. One of the landmark successes of NNs in RL was David Silver's deterministic policy gradient (2014). His approach was to map states to actions directly, rather than going through a Q-value, with the loss function of the neural network guided by the rewards.
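To make the function-approximation idea concrete, here is a minimal sketch (a toy illustration, not tied to the game in the question) of replacing the Q-table lookup Q[s][a] with a parameterised linear function over state features, updated with a semi-gradient TD step:

import numpy as np

n_actions, n_features = 4, 13
w = np.random.randn(n_actions, n_features) * 0.01     # one weight vector per action

def q_value(features, action):
    return w[action] @ features                        # Q(s, a) ~ w_a . phi(s)

def td_update(features, action, reward, next_features, done, alpha=0.1, gamma=0.9):
    # one semi-gradient Q-learning step on a single transition
    best_next = max(q_value(next_features, a) for a in range(n_actions))
    target = reward if done else reward + gamma * best_next
    td_error = target - q_value(features, action)
    w[action] += alpha * td_error * features           # gradient of a linear model is phi(s)
    return td_error

# usage on one made-up transition
phi_s, phi_s2 = np.random.rand(n_features), np.random.rand(n_features)
print(td_update(phi_s, action=2, reward=1.0, next_features=phi_s2, done=False))

Because nearby states produce similar feature vectors, they share weights and therefore generalise to each other, which is exactly what the table cannot do.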
To answer the original question of "how to properly reward the NN": in the simplest (Monte Carlo) policy-gradient setup, you run complete episodes and use the total reward collected in each episode as the signal that scales the network's gradient updates. Here is the original paper: http://proceedings.mlr.press/v32/silver14.pdf
The issue with Monte Carlo methods is their high variance, because each trajectory can be very different from the others. Modern RL (late 2015 onwards) therefore uses actor-critic methods, where the actor is the policy network described above and a second network, the critic, approximates the Q-matrix. The critic stabilizes the actor's learning by giving it feedback after each episode, which reduces the variance of the actor.
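Here is a minimal sketch of the one-step actor-critic idea, using simple tables instead of neural networks and a made-up 3-state chain environment, just to show where the critic's TD error enters the actor's update:

import numpy as np

np.random.seed(0)
n_states, n_actions = 3, 2           # actions: 0 = left, 1 = right
V = np.zeros(n_states)               # critic: state-value estimates
H = np.zeros((n_states, n_actions))  # actor: action preferences (softmax logits)
alpha_v, alpha_pi, gamma = 0.1, 0.1, 0.99

def step(s, a):
    # toy dynamics: stepping right off the last state ends the episode with reward 1
    if a == 1:
        return (s, 1.0, True) if s == n_states - 1 else (s + 1, 0.0, False)
    return max(s - 1, 0), 0.0, False

for episode in range(500):
    s, done = 0, False
    while not done:
        logits = H[s] - H[s].max()                    # softmax policy over actions
        probs = np.exp(logits) / np.exp(logits).sum()
        a = np.random.choice(n_actions, p=probs)

        s_next, r, done = step(s, a)

        td_error = r + (0.0 if done else gamma * V[s_next]) - V[s]
        V[s] += alpha_v * td_error                    # critic update
        grad_log_pi = -probs
        grad_log_pi[a] += 1.0                         # d log pi(a|s) / d H[s]
        H[s] += alpha_pi * td_error * grad_log_pi     # actor update, guided by the critic

        s = s_next

print(np.round(V, 2))    # state values rise towards the terminal reward

After enough episodes the preferences H push the policy towards "right", and in deep RL the two tables are simply replaced by neural networks.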
The two most popular algorithms are Deep Deterministic Policy Gradient (DDPG) and Proximal Policy Optimization (PPO).
I would recommend getting familiar with deterministic policy gradients first, before trying the other ones.
Upvotes: 3
Reputation: 7019
OK, so the problem should be solved with the proper tools. As mentioned in the comments, the right way to do that is with Reinforcement Learning. Here's an algorithm that returns an optimal policy for our environment (based on Q-learning):
import random
import numpy as np

states_space_size = (game.field.leny - 2) * (game.field.lenx - 2)
actions_space_size = len(DIRECTIONS)
QSA = np.zeros(shape=(states_space_size, actions_space_size))
max_iterations = 80
gamma = 1    # discount factor
alpha = 0.9  # learning rate
eps = 0.99   # exploration rate (probability of taking a random action)
s = 0        # initial state
for i in range(max_iterations):
    # explore the world?
    a = choose_an_action(actions_space_size)
    # or not?
    if random.random() > eps:
        a = np.argmax(QSA[s])
    r, s_ = perform_action(s, a, game)
    qsa = QSA[s][a]
    qsa_ = np.max(QSA[s_])  # value of the best action in the next state
    QSA[s][a] = qsa + alpha*(r + gamma*qsa_ - qsa)
    # change state
    s = s_
print(QSA)
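Once the loop has finished, the greedy policy can be read straight off the table; a small addition (assuming DIRECTIONS can be indexed by the action id):

policy = np.argmax(QSA, axis=1)           # best action index for every state
print([DIRECTIONS[a] for a in policy])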
Here's a more detailed explanation, with a simplified example of how to achieve this result.
Upvotes: 0
Reputation: 7019
Not an answer to the question above, but a good place to start obtaining valuable information for your specific network is here.
Upvotes: 0