
Reputation: 1

How to update the weights in my q-network (deeplearning4j)

I want to create a deep Q-network with deeplearning4j, but I cannot figure out how to update the weights of my neural network using the calculated loss.

(I was mainly following this article.)

import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration.ListBuilder;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.linalg.activations.Activation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.learning.config.Nesterovs;

public class DDQN {

    private static final double learningRate = 0.01;

    private final MultiLayerNetwork qnet;
    private final MultiLayerNetwork tnet;

    private final ReplayMemory mem = new ReplayMemory(20000);

    private final Batch batch = new Batch(1000);

    public DDQN(int input, int hidden, int output) {

        ListBuilder conf = new NeuralNetConfiguration.Builder().seed(Rnd.seed).weightInit(WeightInit.XAVIER)
                .updater(new Nesterovs(learningRate, 0.9)).list()
                .layer(new DenseLayer.Builder().nIn(input).nOut(hidden).activation(Activation.IDENTITY).build())
                // the second layer takes the first layer's output, so nIn must be hidden
                .layer(new DenseLayer.Builder().nIn(hidden).nOut(hidden).activation(Activation.IDENTITY).build())
                .layer(new DenseLayer.Builder().nIn(hidden).nOut(output).activation(Activation.IDENTITY).build());

        qnet = new MultiLayerNetwork(conf.build());
        qnet.init();

        tnet = qnet.clone();
    }

    public INDArray tmpState = null;
    public int tmpAction = -1;

    public int getAction(double[] state) {
        tmpState = Nd4j.createFromArray(new double[][] { state });
        tmpAction = tnet.predict(tmpState)[0];
        return tmpAction;
    }

    public void addResult(double reward, INDArray newState) {
        mem.add(tmpState, tmpAction, reward, newState);
    }

    public void train(int size) {
        for (int i = 0; i < batch.size(); i++) {
            // get q value of chosen action
            INDArray out = qnet.output(batch.states[i]);
            double q0 = out.getRow(0).getDouble(batch.actions[i]);
            // get highest q value of the next state
            out = tnet.output(batch.newStates[i]);
            double q1 = out.maxNumber().doubleValue();
            // squared error for this sample
            double err = q0 - (batch.rewards[i] + q1);
            double mse = err * err;
            // update neural net
            // ??????
        }
    }
}

Replay Memory: (stores what the AI experienced for later training)

import org.nd4j.linalg.api.ndarray.INDArray;

public class ReplayMemory {

    private final INDArray[] states;
    private final int[] actions;
    private final double[] rewards;
    private final INDArray[] newStates;
    private int pos = 0;
    private boolean filled = false;

    public ReplayMemory(int size) {
        states = new INDArray[size];
        actions = new int[size];
        rewards = new double[size];
        newStates = new INDArray[size];
    }

    public void fillBatch(Batch b) {
        // pos is the next write index, so only slots 0..pos-1 hold data until the buffer wraps;
        // at least one experience must have been added before sampling
        final int max = filled ? states.length : pos;
        int r;
        for (int i = 0; i < b.states.length; i++) {
            r = Rnd.r.nextInt(max);
            b.states[i] = states[r];
            b.actions[i] = actions[r];
            b.rewards[i] = rewards[r];
            b.newStates[i] = newStates[r];
        }
    }

    public void add(INDArray state, int action, double reward, INDArray newState) {
        this.states[pos] = state;
        this.actions[pos] = action;
        this.rewards[pos] = reward;
        this.newStates[pos] = newState;
        // wrap around and start overwriting the oldest entries
        if (++pos == this.size()) {
            pos = 0;
            filled = true;
        }
    }

    public int size() {
        return states.length;
    }
}

Batch: (temporarily stores the current batch of experiences during training)

import org.nd4j.linalg.api.ndarray.INDArray;

public class Batch {
    public final INDArray[] states;
    public final int[] actions;
    public final double[] rewards;
    public final INDArray[] newStates;

    public Batch(int size) {
        states = new INDArray[size];
        actions = new int[size];
        rewards = new double[size];
        newStates = new INDArray[size];
    }

    public int size() {
        return states.length;
    }
}

I have already tried searching Google and reading the documentation, with no luck.

Upvotes: 0

Views: 88

Answers (1)

Adam Gibson

Reputation: 3205

At one point we were maintaining a reinforcement learning library called RL4J.

The project is retired now, but I'd be happy to give you a few tips based on how we did weight updates.

We built a lot of tooling around what we call external gradients/losses.

Here's calculating the gradients:

    public Gradients computeGradients(FeaturesLabels featuresLabels) {
        // mln.gradient() below returns the gradient from the network's most recent
        // computeGradientAndScore() call
        Collection<TrainingListener> iterationListeners = mln.getListeners();
        if (iterationListeners != null && iterationListeners.size() > 0) {
            for (TrainingListener l : iterationListeners) {
                l.onGradientCalculation(mln);
            }
        }
        Gradients result = new Gradients(featuresLabels.getBatchSize());
        result.putGradient(CommonGradientNames.QValues, mln.gradient());
        return result;
    }
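
For the setup in the question, the labels that feed a gradient computation like this could be built as follows. This is a minimal plain-Java sketch, not RL4J code: `QTargets`, `buildTargets`, and the discount factor `gamma` are hypothetical names introduced here (the question's `train` loop effectively uses `gamma = 1`). The idea is to copy the network's current Q-values and overwrite only the entry of the chosen action with `reward + gamma * maxNextQ`, so that a squared-error loss yields a nonzero error for that action alone.

```java
import java.util.Arrays;

// Hypothetical helper for illustration; not part of DL4J or RL4J.
class QTargets {

    // Returns a copy of qValues in which only the chosen action's entry is
    // replaced by the TD target: reward + gamma * maxNextQ.
    static double[] buildTargets(double[] qValues, int action, double reward,
                                 double maxNextQ, double gamma) {
        double[] targets = Arrays.copyOf(qValues, qValues.length);
        targets[action] = reward + gamma * maxNextQ;
        return targets;
    }

    public static void main(String[] args) {
        double[] q = {0.5, 1.0, -0.25};   // current Q-values for three actions
        double[] t = buildTargets(q, 1, 2.0, 3.0, 0.5);
        System.out.println(Arrays.toString(t));
    }
}
```

An array built this way can serve as the regression label for the Q-network, provided the network ends in a layer with a squared-error loss. The entries for the other actions equal the current predictions, so they contribute no error.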

You can see how we did it in updating DQN here:

In short:

    public void applyGradients(Gradients gradients) {
        Gradient qValues = gradients.getGradient(CommonGradientNames.QValues);

        MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations();
        int iterationCount = mlnConf.getIterationCount();
        int epochCount = mlnConf.getEpochCount();
        // let the configured updater (learning rate, momentum, ...) process the raw gradient
        mln.getUpdater().update(mln, qValues, iterationCount, epochCount, (int)gradients.getBatchSize(), LayerWorkspaceMgr.noWorkspaces());
        // subtract the processed gradient from the parameters (the step described below)
        mln.params().subi(qValues.gradient());
        Collection<TrainingListener> iterationListeners = mln.getListeners();
        if (iterationListeners != null && iterationListeners.size() > 0) {
            for (TrainingListener listener : iterationListeners) {
                listener.iterationDone(mln, iterationCount, epochCount);
            }
        }
        mlnConf.setIterationCount(iterationCount + 1);
    }

    public void applyGradient(Gradient[] gradient, int batchSize) {
        MultiLayerConfiguration mlnConf = mln.getLayerWiseConfigurations();
        int iterationCount = mlnConf.getIterationCount();
        int epochCount = mlnConf.getEpochCount();
        mln.getUpdater().update(mln, gradient[0], iterationCount, epochCount, batchSize, LayerWorkspaceMgr.noWorkspaces());
        // same subtraction step for the array-based variant
        mln.params().subi(gradient[0].gradient());
        Collection<TrainingListener> iterationListeners = mln.getListeners();
        if (iterationListeners != null && iterationListeners.size() > 0) {
            for (TrainingListener listener : iterationListeners) {
                listener.iterationDone(mln, iterationCount, epochCount);
            }
        }
        mlnConf.setIterationCount(iterationCount + 1);
    }

In our case, we simply subtracted the calculated gradient, after the updater had processed it, from the network's parameters. I'd be happy to expand on this as well.
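
The subtraction step can be illustrated without any DL4J types. The following is a minimal sketch under stated assumptions: a linear Q-function, a loss of `0.5 * (q - target)^2`, and the hypothetical names `LinearQStep`, `q`, and `step`, none of which come from RL4J. In the snippets above the network's updater processes the gradient first; this sketch replaces that with multiplication by a fixed learning rate.

```java
// Hypothetical plain-Java illustration; no DL4J types involved.
class LinearQStep {

    // Linear Q-function: Q(s, a) = dot(weights[a], state).
    static double q(double[][] weights, double[] state, int action) {
        double sum = 0.0;
        for (int i = 0; i < state.length; i++) {
            sum += weights[action][i] * state[i];
        }
        return sum;
    }

    // One gradient-descent step on 0.5 * (q - target)^2 for a single transition.
    // The gradient w.r.t. weights[action][i] is (q - target) * state[i]; the
    // target is treated as a constant. Returns the error before the update.
    static double step(double[][] weights, double[] state, int action,
                       double target, double learningRate) {
        double err = q(weights, state, action) - target;
        for (int i = 0; i < state.length; i++) {
            double grad = err * state[i];
            weights[action][i] -= learningRate * grad; // params = params - lr * gradient
        }
        return err;
    }

    public static void main(String[] args) {
        double[][] w = new double[2][2]; // 2 actions, 2 state features, all zeros
        double[] s = {1.0, 2.0};
        double target = 1.0;             // e.g. reward + max Q of the next state
        for (int k = 0; k < 50; k++) {
            step(w, s, 0, target, 0.1);
        }
        System.out.println(q(w, s, 0)); // approaches the target
    }
}
```

Only the weights of the chosen action move, which mirrors the question's per-sample error `q0 - (reward + q1)`: the other actions have no error term and therefore no gradient.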

Upvotes: 0
