Reputation: 1
I want to create a deep Q-network with deeplearning4j, but I cannot figure out how to update the weights of my neural network using the calculated loss.
(I was mainly following this article.)
public class DDQN {
    private static final double learningRate = 0.01;
    private final MultiLayerNetwork qnet;
    private final MultiLayerNetwork tnet;
    private final ReplayMemory mem = new ReplayMemory(20000);
    private final Batch batch = new Batch(1000);

    public DDQN(int input, int hidden, int output) {
        ListBuilder conf = new NeuralNetConfiguration.Builder().seed(Rnd.seed).weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9)).list()
                .layer(new DenseLayer.Builder().nIn(input).nOut(hidden).activation(Activation.IDENTITY).build())
                .layer(new DenseLayer.Builder().nIn(hidden).nOut(hidden).activation(Activation.IDENTITY).build())
                .layer(new DenseLayer.Builder().nIn(hidden).nOut(output).activation(Activation.IDENTITY)
                        .build());
        qnet = new MultiLayerNetwork(conf.build());
        qnet.init();
        tnet = qnet.clone();
    }

    public INDArray tmpState = null;
    public int tmpAction = -1;

    public int getAction(double[] state) {
        tmpState = Nd4j.createFromArray(new double[][] { state });
        tmpAction = tnet.predict(tmpState)[0];
        return tmpAction;
    }

    public void addResult(double reward, INDArray newState) {
        mem.add(tmpState, tmpAction, reward, newState);
    }

    public void train(int size) {
        mem.fillBatch(batch);
        for (int i = 0; i < batch.size(); i++) {
            // get q value of the chosen action
            INDArray out = qnet.output(batch.states[i]);
            double q0 = out.getRow(0).getDouble(batch.actions[i]);
            // get highest q value of the next state
            out = tnet.output(batch.newStates[i]);
            double q1 = out.maxNumber().doubleValue();
            // squared error (the loss for this sample)
            double err = q0 - (batch.rewards[i] + q1);
            double mse = err * err;
            // update neural net
            // ??????
        }
    }
}
Replay memory (stores what the AI experienced, for later training):
public class ReplayMemory {
    private final INDArray[] states;
    private final int[] actions;
    private final double[] rewards;
    private final INDArray[] newStates;
    private int pos = 0;
    private boolean filled = false;

    public ReplayMemory(int size) {
        states = new INDArray[size];
        actions = new int[size];
        rewards = new double[size];
        newStates = new INDArray[size];
    }

    public void fillBatch(Batch b) {
        // only sample from slots that have actually been written
        final int max = filled ? states.length : pos;
        int r;
        for (int i = 0; i < b.states.length; i++) {
            r = Rnd.r.nextInt(max);
            b.states[i] = states[r];
            b.actions[i] = actions[r];
            b.rewards[i] = rewards[r];
            b.newStates[i] = newStates[r];
        }
    }

    public void add(INDArray state, int action, double reward, INDArray newState) {
        this.states[pos] = state;
        this.actions[pos] = action;
        this.rewards[pos] = reward;
        this.newStates[pos] = newState;
        if (++pos == this.size()) {
            pos = 0;
            filled = true;
        }
    }

    public int size() {
        return states.length;
    }
}
Batch (temporarily stores the current batch of experiences during training):
public class Batch {
    public final INDArray[] states;
    public final int[] actions;
    public final double[] rewards;
    public final INDArray[] newStates;

    public Batch(int size) {
        states = new INDArray[size];
        actions = new int[size];
        rewards = new double[size];
        newStates = new INDArray[size];
    }

    public int size() {
        return states.length;
    }
}
I have already tried using Google and reading the documentation, with no luck.
Upvotes: 0
Views: 88
Reputation: 3205
At one point we were maintaining a reinforcement learning library called rl4j.
The project is retired now, but I'd be happy to give you a few tips on how we did weight updates.
We built a lot of tooling around what we call external gradients/losses.
Here's how we calculated the gradients:
@Override
public Gradients computeGradients(FeaturesLabels featuresLabels) {
    mln.setInput(featuresLabels.getFeatures().get(0));
    mln.setLabels(featuresLabels.getLabels(CommonLabelNames.QValues));
    mln.computeGradientAndScore();
    Collection<TrainingListener> iterationListeners = mln.getListeners();
    if (iterationListeners != null && iterationListeners.size() > 0) {
        for (TrainingListener l : iterationListeners) {
            l.onGradientCalculation(mln);
        }
    }
    Gradients result = new Gradients(featuresLabels.getBatchSize());
    result.putGradient(CommonGradientNames.QValues, mln.gradient());
    return result;
}
You can see how we did this when updating the DQN here: https://github.com/deeplearning4j/deeplearning4j/blob/1.0.0-M1.1/rl4j/rl4j-core/src/main/java/org/deeplearning4j/rl4j/network/dqn/DQN.java
In short:
@Override
public void applyGradients(Gradients gradients) {
    Gradient qValues = gradients.getGradient(CommonGradientNames.QValues);
    MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations();
    int iterationCount = mlnConf.getIterationCount();
    int epochCount = mlnConf.getEpochCount();
    mln.getUpdater().update(mln, qValues, iterationCount, epochCount, (int) gradients.getBatchSize(), LayerWorkspaceMgr.noWorkspaces());
    mln.params().subi(qValues.gradient());
    Collection<TrainingListener> iterationListeners = mln.getListeners();
    if (iterationListeners != null && iterationListeners.size() > 0) {
        for (TrainingListener listener : iterationListeners) {
            listener.iterationDone(mln, iterationCount, epochCount);
        }
    }
    mlnConf.setIterationCount(iterationCount + 1);
}

public void applyGradient(Gradient[] gradient, int batchSize) {
    MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations();
    int iterationCount = mlnConf.getIterationCount();
    int epochCount = mlnConf.getEpochCount();
    mln.getUpdater().update(mln, gradient[0], iterationCount, epochCount, batchSize, LayerWorkspaceMgr.noWorkspaces());
    mln.params().subi(gradient[0].gradient());
    Collection<TrainingListener> iterationListeners = mln.getListeners();
    if (iterationListeners != null && iterationListeners.size() > 0) {
        for (TrainingListener listener : iterationListeners) {
            listener.iterationDone(mln, iterationCount, epochCount);
        }
    }
    mlnConf.setIterationCount(iterationCount + 1);
}
In our case, we just subtracted the calculated gradient from the parameters after running it through the updater. I'd be happy to expand on this as well.
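To make that concrete against the code in the question, here is a minimal sketch of wiring those two steps (compute the gradient from labels, run it through the updater, subtract it from the parameters) into the DDQN.train() method. It is only an illustration, not rl4j code, and it rests on two assumptions that are not in your original code: the network's last layer is changed to an OutputLayer with an MSE loss (computeGradientAndScore() needs an output layer to turn labels into a loss), and a discount factor gamma is added.
// Sketch only: assumes qnet's last layer is now something like
// .layer(new OutputLayer.Builder(LossFunctions.LossFunction.MSE)
//         .nIn(hidden).nOut(output).activation(Activation.IDENTITY).build())
private static final double gamma = 0.99; // assumed discount factor

public void train(int size) {
    mem.fillBatch(batch);
    for (int i = 0; i < batch.size(); i++) {
        INDArray state = batch.states[i];
        // labels = current Q-values with the chosen action's entry replaced
        // by the TD target, so only that entry produces an error
        INDArray labels = qnet.output(state).dup();
        double q1 = tnet.output(batch.newStates[i]).maxNumber().doubleValue();
        labels.putScalar(0, batch.actions[i], batch.rewards[i] + gamma * q1);

        // compute the gradient of the MSE loss (as in computeGradients above)
        qnet.setInput(state);
        qnet.setLabels(labels);
        qnet.computeGradientAndScore();
        Gradient g = qnet.gradient();

        // run the gradient through the updater, then subtract it from the
        // parameters (as in applyGradients above); batch size is 1 here
        MultiLayerConfiguration conf = qnet.getLayerWiseConfigurations();
        int iteration = conf.getIterationCount();
        int epoch = conf.getEpochCount();
        qnet.getUpdater().update(qnet, g, iteration, epoch, 1, LayerWorkspaceMgr.noWorkspaces());
        qnet.params().subi(g.gradient());
        conf.setIterationCount(iteration + 1);
    }
}
If you don't need the external-gradient split, calling qnet.fit(state, labels) with the same labels does the forward pass, loss, backprop, and updater step in one call; looping one transition at a time works but is slow, so the usual approach is to stack the whole batch into one features matrix and one labels matrix and fit once per batch.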
Upvotes: 0