Susmit Agrawal

How to build a basic neural network in Java?

I am trying to build a basic neural network in Java to compute the logical XOR function.

The network has two input neurons, one hidden layer with three neurons and a single output neuron.

However, after a few iterations, the output error becomes NaN.

I have gone through other implementations and tutorials for implementing neural networks, but I cannot find the error. I feel the issue lies in my backward function.

Please help me understand where I went wrong.

My code:

import org.ejml.simple.SimpleMatrix;

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// SimpleMatrix constructor format: SimpleMatrix(rows, cols)
//The layers are represented as a matrix with 1 row and multiple columns (row vector)
public class Network {
    private SimpleMatrix inputs, outputs, hidden, W1, W2, predicted;
    static final double LEARNING_RATE = 0.3;

    Network(List<double[]> ips, List<double[]> ops){
        hidden = new SimpleMatrix(1, 3);
        W1 = new SimpleMatrix(ips.get(0).length, hidden.numCols());
        W2 = new SimpleMatrix(hidden.numCols(), ops.get(0).length);

        for(int i=0;i<5000;i++){
            for(int j=0;j<ips.size();j++){
                train(ips.get(j), ops.get(j));
            }
        }
    }

    //Prints output matrix
    SimpleMatrix predict(double[] ip){
        SimpleMatrix bkpInputs = inputs.copy();
        SimpleMatrix bkpOutputs = outputs.copy();

        inputs = new SimpleMatrix(1, ip.length);
        inputs.setRow(0, 0, ip);

        inputs = bkpInputs;
        outputs = bkpOutputs;

        return predicted;
    }

    void train(double[] inputs, double[] outputs){
        this.inputs = new SimpleMatrix(1, inputs.length);
        this.inputs.setRow(0, 0, inputs);
        this.outputs = new SimpleMatrix(1, outputs.length);
        this.predicted = new SimpleMatrix(1,outputs.length);
    }

    private void initWeights(SimpleMatrix... W){
        Random random = new Random();
        for (SimpleMatrix aW : W) {
            for (int i = 0; i < aW.numRows(); i++)
                for (int j = 0; j < aW.numCols(); j++)
                    aW.set(i, j, random.nextDouble());
        }
    }

    //Using logistic function
    double sigmoid(double x){
        return (1/(1+Math.exp(-x)));
    }

    double sigmoidPrime(double x){
        return sigmoid(x)/(1-sigmoid(x));
    }

    void forward(){
        hidden = inputs.mult(W1);
        for(int i=0;i<hidden.numCols();i++){
            double x = sigmoid(hidden.get(0,i));
            hidden.set(0, i, x);
        }
        predicted = hidden.mult(W2);
        for(int i=0;i<predicted.numRows();i++){
            for(int j=0;j<predicted.numCols();j++){
                predicted.set(i,j, sigmoid(predicted.get(i,j)));
            }
        }
    }

    void backward(){

        //Error in output
        double o_error = 0.0;
        //Error functions I tried: (1/2)( (predicted-actual) ^ 2) and (predicted - actual)
        for(int i=0;i<outputs.numCols();i++)
            o_error += (predicted.get(0, i)-outputs.get(0, i));//Math.pow(predicted.get(0, i)-outputs.get(0, i), 2)/2;
        //Checking output error

        //Output deltas
        SimpleMatrix o_deltas = new SimpleMatrix(1, outputs.numCols());
        for(int i=0;i<outputs.numCols();i++)
            o_deltas.set(0, i, o_error*sigmoidPrime(predicted.get(0, i))); 

        //Error in hidden layer and deltas
        double h_error =;
        SimpleMatrix h_deltas = new SimpleMatrix(1, hidden.numCols());
        for(int i=0;i<hidden.numCols();i++)
            h_deltas.set(0, i, h_error*sigmoidPrime(hidden.get(0, i)));

        //Hidden->Output layer update
        SimpleMatrix W2_delta = W2.mult(o_deltas.transpose());
        for(int i=0;i<W2.numRows();i++){
            for(int j=0;j<W2.numCols();j++){
                W2.set(i,j, W2.get(i,j) + LEARNING_RATE*W2_delta.get(i, 0));
            }
        }

        //Input->Hidden layer update
        SimpleMatrix W1_delta = W1.mult(h_deltas.transpose());
        for(int i=0;i<W1.numRows();i++){
            for(int j=0;j<W1.numCols();j++){
                W1.set(i,j, W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));
            }
        }
    }

    public static void main(String[] args){
        double[][] ips = {

        double[][] ops = {

        List<double[]> ip = new ArrayList<>();
        List<double[]> op = new ArrayList<>();

        for(int i=0;i<ips.length;i++){

        double[] testip = {1,0};
        Network n = new Network(ip,op);

Answers (2)


This may not be what's causing the issue, but I noticed this line where you update the weights:

W1.get(i,j) + LEARNING_RATE*W1_delta.get(i, 0));

I don't think that is the correct formula. The weight change should also be multiplied by the output of the connected node.

So your code should be:

W1(i,j) += LEARNING_RATE * W1_delta.get(i, 0) *  <output from the connected node>;

It may not solve it, but it's worth a try!
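
To make the suggested update concrete, here is a minimal sketch using plain arrays rather than EJML. It is not the asker's code: the class `DeltaRuleSketch`, the method `updateWeights`, and the parameter names are made up for illustration. It also assumes the common convention in which each delta belongs to the destination node and already carries the sign that makes `+=` reduce the error.

```java
import java.util.Arrays;

class DeltaRuleSketch {
    static final double LEARNING_RATE = 0.3;

    // weights[i][j] connects source node i to destination node j.
    // activations[i] is the output of source node i; deltas[j] is the
    // delta of destination node j (hypothetical names, see the text above).
    static void updateWeights(double[][] weights, double[] activations, double[] deltas) {
        for (int i = 0; i < weights.length; i++) {
            for (int j = 0; j < weights[i].length; j++) {
                // Each weight moves by: learning rate * delta * source output.
                weights[i][j] += LEARNING_RATE * deltas[j] * activations[i];
            }
        }
    }

    public static void main(String[] args) {
        // Three hidden nodes feeding one output node, as in the question.
        double[][] w2 = {{0.5}, {0.5}, {0.5}};
        double[] hiddenOutputs = {1.0, 0.0, 0.5};
        double[] outputDeltas = {0.1};
        updateWeights(w2, hiddenOutputs, outputDeltas);
        System.out.println(Arrays.deepToString(w2));
    }
}
```

Note that a weight whose source node outputs 0 does not change at all, which is the effect the extra factor is meant to capture.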

Try lower learning rates. When the error is NaN, it often means that your cost/error function has exploded. Try something in the range [10^-5, 10^-3].
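
As a toy illustration of how a step size that is too large ends in NaN, here is plain gradient descent on f(w) = w^2 rather than on the XOR network. The class `LearningRateSketch` and the method `descend` are hypothetical names, and the rate at which divergence starts is specific to this toy function, so treat it only as a sketch of the failure mode.

```java
class LearningRateSketch {
    // Runs gradient descent on f(w) = w^2, whose gradient is 2w,
    // starting from w = 1.0, and returns the final value of w.
    static double descend(double learningRate, int steps) {
        double w = 1.0;
        for (int i = 0; i < steps; i++) {
            w -= learningRate * 2.0 * w;
        }
        return w;
    }

    public static void main(String[] args) {
        double large = descend(1.5, 2000);
        double small = descend(0.001, 2000);
        // Double.isNaN and Double.isFinite are cheap checks to run on the
        // error or the weights after every training step.
        System.out.println("rate 1.5   -> " + large + " (NaN? " + Double.isNaN(large) + ")");
        System.out.println("rate 0.001 -> " + small + " (finite? " + Double.isFinite(small) + ")");
    }
}
```

With a rate of 1.5, each step multiplies w by -2, so the value overflows to infinity and then turns into NaN. With a rate of 0.001, w shrinks steadily toward 0.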

