Reputation: 11
I recently wrote a neural network by following a series of YouTube videos from the Coding Train channel. The original was written in JavaScript; I wrote mine in Java. It sometimes works correctly, but other times I get NaN as outputs and I can't figure out why.
Can anyone help? There is a Matrix class for the matrix math and the NeuralNetwork class itself, with a test problem. The first output should be 1 if the input has more 0s than 1s, and the second output should be 1 otherwise.
Edit: I found where the problem occurs, but I still can't figure out why it happens. It happens in my static dot product method in the Matrix class: sometimes the data of one or both matrices is NaN.
Edit 2: I checked, and the inputs are valid in the constructor, but in the feedForward method they are sometimes NaN. Could it be because I'm using a 10-year-old laptop? The code doesn't seem to have any problem.
Solved: I found the problem! In feedForward I didn't map sigmoid over the output Matrix -_-
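A likely reason the missing activation produces NaN rather than just wrong numbers: dsigmoid(y) = y * (1 - y) assumes y is already a sigmoid output in (0, 1). If a raw, unbounded output is passed instead and it falls outside that range, the computed gradient pushes the output further away from the target, the weights can grow until they overflow to Infinity, and arithmetic such as Infinity - Infinity then yields NaN. The posted train method computes its outputs the same way as feedForward, so it presumably needs the same change. Below is a minimal sketch of a forward pass that applies sigmoid on every layer. It uses plain arrays instead of the question's Matrix class, and the names ForwardPassSketch and layer are made up for illustration.

```java
// Sketch only: plain arrays stand in for the question's Matrix class.
final class ForwardPassSketch {

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Computes sigmoid(w * in + b) for a single layer.
    static double[] layer(double[][] w, double[] b, double[] in) {
        double[] out = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int k = 0; k < in.length; k++) {
                sum += w[i][k] * in[k];
            }
            // The activation is applied on every layer, including the output layer.
            out[i] = sigmoid(sum);
        }
        return out;
    }

    // Two-layer forward pass: input -> hidden -> output, both squashed into (0, 1).
    static double[] feedForward(double[][] wIH, double[] bH,
                                double[][] wHO, double[] bO, double[] in) {
        double[] hidden = layer(wIH, bH, in);
        return layer(wHO, bO, hidden);
    }
}
```

In terms of the posted code, this corresponds to calling map with sigmoid on the output matrix right after adding biasO, in both feedForward and train.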
public class NeuralNetwork {
//private int inputNodes, hiddenNodes, outputNodes;
private Matrix weightsIH, weightsHO, biasH, biasO;
private double learningRate = 0.1;
public NeuralNetwork(int inputNodes, int hiddenNodes, int outputNodes) {
//this.inputNodes = inputNodes;
//this.hiddenNodes = hiddenNodes;
//this.outputNodes = outputNodes;
weightsIH = new Matrix(hiddenNodes, inputNodes);
weightsHO = new Matrix(outputNodes, hiddenNodes);
weightsIH.randomize();
weightsHO.randomize();
biasH = new Matrix(hiddenNodes, 1);
biasO = new Matrix(outputNodes, 1);
biasH.randomize();
biasO.randomize();
}
public void setLearningRate(double learningRate) {
this.learningRate = learningRate;
}
public double sigmoid(double x) {
return 1 / (1 + Math.exp(-x));
}
public double dsigmoid(double y) {
return y * (1 - y);
}
public double[] feedForward(double[] inputArray) throws Exception {
Matrix inputs = Matrix.fromArray(inputArray);
Matrix hidden = Matrix.dot(weightsIH, inputs);
hidden.add(biasH);
hidden.map(f -> sigmoid(f));
Matrix output = Matrix.dot(weightsHO, hidden);
output.add(biasO);
return output.toArray();
}
public void train(double[] inputArray, double[] targetsArray) throws Exception {
Matrix targets = Matrix.fromArray(targetsArray);
// feed forward algorithm //
Matrix inputs = Matrix.fromArray(inputArray);
Matrix hidden = Matrix.dot(weightsIH, inputs);
hidden.add(biasH);
hidden.map(f -> sigmoid(f));
Matrix outputs = Matrix.dot(weightsHO, hidden);
outputs.add(biasO);
// feed forward algorithm //
// Calculate outputs ERRORS
Matrix outputErrors = Matrix.subtract(targets, outputs);
// Calculate outputs Gradients
Matrix outputsGradients = Matrix.map(outputs, f -> dsigmoid(f));
outputsGradients.multiply(outputErrors);
outputsGradients.multiply(learningRate);
// Calculate outputs Deltas
Matrix hidden_t = Matrix.transpose(hidden);
Matrix weightsHO_deltas = Matrix.dot(outputsGradients, hidden_t);
// adjust outputs weights
weightsHO.add(weightsHO_deltas);
// adjust outputs bias
biasO.add(outputsGradients);
// Calculate hidden layer ERRORS
Matrix weightsHO_t = Matrix.transpose(weightsHO);
Matrix hiddenErrors = Matrix.dot(weightsHO_t, outputErrors);
// Calculate hidden Gradients
Matrix hiddenGradients = Matrix.map(hidden, f -> dsigmoid(f));
hiddenGradients.multiply(hiddenErrors);
hiddenGradients.multiply(learningRate);
// Calculate hidden Deltas
Matrix inputs_t = Matrix.transpose(inputs);
Matrix weightsIH_deltas = Matrix.dot(hiddenGradients, inputs_t);
// adjust hidden weights
weightsIH.add(weightsIH_deltas);
// adjust hidden bias
biasH.add(hiddenGradients);
}
public static void print(double[] data) {
for (double d : data) {
System.out.print(d + " ");
}
System.out.println();
}
public static void main(String[] args) {
NeuralNetwork nn = new NeuralNetwork(3, 4, 2);
double[][] trainingInputs = {{0, 0, 0}, {0, 0, 1}, {0, 1, 0}, {0, 1, 1}, {1, 0, 0}, {1, 0, 1}, {1, 1, 0}, {1, 1, 1}};
double[][] targets = {{1, 0}, {1, 0}, {1, 0}, {0, 1}, {1, 0}, {0, 1}, {0, 1}, {1, 0}};
for (int i = 0; i < 10000; i++) {
for (int j = 0; j < trainingInputs.length; j++) {
try {
nn.train(trainingInputs[j], targets[j]);
} catch (Exception e) {
e.printStackTrace();
}
}
}
double[] output;
try {
output = nn.feedForward(new double[]{0, 0, 0});
print(output);
output = nn.feedForward(new double[]{0, 0, 1});
print(output);
output = nn.feedForward(new double[]{0, 1, 0});
print(output);
output = nn.feedForward(new double[]{0, 1, 1});
print(output);
output = nn.feedForward(new double[]{1, 0, 0});
print(output);
output = nn.feedForward(new double[]{1, 0, 1});
print(output);
output = nn.feedForward(new double[]{1, 1, 0});
print(output);
output = nn.feedForward(new double[]{1, 1, 1});
print(output);
} catch (Exception e) {
e.printStackTrace();
}
}
}
public class Matrix {
public double[][] data;
public Matrix(int row, int col) {
data = new double[row][col];
}
public Matrix(double[][] data) {
this.data = data;
}
public void randomize() {
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
// Uniform value in [-1, 1); Math.random() avoids allocating a new Random per element
data[i][j] = Math.random() * 2 - 1;
}
}
}
public Matrix transpose() {
Matrix result = new Matrix(data[0].length, data.length);
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
result.data[j][i] = data[i][j];
}
}
return result;
}
public static Matrix transpose(Matrix m) {
Matrix result = new Matrix(m.data[0].length, m.data.length);
for (int i = 0; i < m.data.length; i++) {
for (int j = 0; j < m.data[0].length; j++) {
result.data[j][i] = m.data[i][j];
}
}
return result;
}
public void add(double n) {
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] += n;
}
}
}
public void subtract(double n) {
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] -= n;
}
}
}
public void add(Matrix m) throws Exception {
if (!(data.length == m.data.length && data[0].length == m.data[0].length))
throw new Exception("columns and rows don't match!");
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] += m.data[i][j];
}
}
}
public void subtract(Matrix m) throws Exception {
if (!(data.length == m.data.length && data[0].length == m.data[0].length))
throw new Exception("columns and rows don't match!");
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] -= m.data[i][j];
}
}
}
public static Matrix add(Matrix m1, Matrix m2) throws Exception {
if (!(m1.data.length == m2.data.length && m1.data[0].length == m2.data[0].length))
throw new Exception("columns and rows don't match!");
Matrix result = new Matrix(m1.data.length, m1.data[0].length);
for (int i = 0; i < result.data.length; i++) {
for (int j = 0; j < result.data[0].length; j++) {
result.data[i][j] = m1.data[i][j] + m2.data[i][j];
}
}
return result;
}
public static Matrix subtract(Matrix m1, Matrix m2) throws Exception {
if (!(m1.data.length == m2.data.length && m1.data[0].length == m2.data[0].length))
throw new Exception("columns and rows don't match!");
Matrix result = new Matrix(m1.data.length, m1.data[0].length);
for (int i = 0; i < result.data.length; i++) {
for (int j = 0; j < result.data[0].length; j++) {
result.data[i][j] = m1.data[i][j] - m2.data[i][j];
}
}
return result;
}
public void multiply(double n) {
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] *= n;
}
}
}
public void multiply(Matrix m) throws Exception {
if (!(data.length == m.data.length && data[0].length == m.data[0].length))
throw new Exception("columns and rows don't match!");
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] *= m.data[i][j];
}
}
}
public static Matrix multiply(Matrix m1, Matrix m2) throws Exception {
if (!(m1.data.length == m2.data.length && m1.data[0].length == m2.data[0].length))
throw new Exception("columns and rows don't match!");
Matrix result = new Matrix(m1.data.length, m1.data[0].length);
for (int i = 0; i < m1.data.length; i++) {
for (int j = 0; j < m1.data[0].length; j++) {
result.data[i][j] = m1.data[i][j] * m2.data[i][j];
}
}
return result;
}
public Matrix dot(Matrix m) throws Exception {
if (data[0].length != m.data.length)
throw new Exception("columns and rows don't match!");
Matrix result = new Matrix(data.length, m.data[0].length);
for (int i = 0; i < result.data.length; i++) {
for (int j = 0; j < result.data[0].length; j++) {
double sum = 0;
for (int k = 0; k < data[0].length; k++) {
sum += data[i][k] * m.data[k][j];
}
result.data[i][j] = sum;
}
}
return result;
}
public static Matrix dot(Matrix m1, Matrix m2) throws Exception {
if (m1.data[0].length != m2.data.length)
throw new Exception("columns and rows don't match!");
Matrix result = new Matrix(m1.data.length, m2.data[0].length);
for (int i = 0; i < result.data.length; i++) {
for (int j = 0; j < result.data[0].length; j++) {
double sum = 0;
for (int k = 0; k < m1.data[0].length; k++) {
sum += m1.data[i][k] * m2.data[k][j];
}
result.data[i][j] = sum;
}
}
return result;
}
public static interface Func {
public double method(double d);
}
public void map(Func f) {
for (int i = 0 ; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
data[i][j] = f.method(data[i][j]);
}
}
}
public static Matrix map(Matrix m, Func f) {
Matrix result = new Matrix(m.data.length, m.data[0].length);
for (int i = 0 ; i < m.data.length; i++) {
for (int j = 0; j < m.data[0].length; j++) {
result.data[i][j] = f.method(m.data[i][j]);
}
}
return result;
}
public static Matrix fromArray(double[] arr) {
Matrix res = new Matrix(arr.length, 1);
for (int i = 0; i < arr.length; i++) {
res.data[i][0] = arr[i];
}
return res;
}
public double[] toArray() {
double[] res = new double[data.length];
for (int i = 0; i < data.length; i++) {
res[i] = data[i][0];
}
return res;
}
public void print() {
for (int i = 0; i < data.length; i++) {
for (int j = 0; j < data[0].length; j++) {
System.out.print(data[i][j] + " ");
}
System.out.println();
}
}
}
Upvotes: 1
Views: 243
Reputation: 1977
You have a couple of options for debugging this, and they can even be used together.
Add debug output for all of your calculations so you can see what exactly is causing the unexpected value. For example, you have...
public double sigmoid(double x) {
return 1 / (1 + Math.exp(-x));
}
but you could see what that is doing by making it...
public double sigmoid(double x) {
double sigmoid = 1 / (1 + Math.exp(-x));
System.out.println("1 / (1 + Math.exp(" + (-x) + ")) = " + sigmoid);
return sigmoid;
}
Do that anywhere you perform a calculation which could result in your unexpected value.
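When reading that output, it helps to know which operations can create a NaN in the first place. The sketch below (the class name NanSources is made up for illustration) prints a few standard cases. Note that the sigmoid from the question returns NaN only when its input is already NaN, so a NaN reported there points to an earlier step.

```java
// Sketch: common ways a double becomes NaN, and how sigmoid reacts to extreme inputs.
final class NanSources {

    static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    public static void main(String[] args) {
        double inf = Double.POSITIVE_INFINITY;

        System.out.println(0.0 / 0.0);            // NaN
        System.out.println(inf - inf);            // NaN
        System.out.println(0.0 * inf);            // NaN
        System.out.println(Double.MAX_VALUE * 2); // Infinity: overflow comes first, NaN follows later

        // Math.exp overflows to Infinity here, and 1 / Infinity is 0.0, not NaN.
        System.out.println(sigmoid(-1e6));        // 0.0
        // NaN in, NaN out: sigmoid propagates NaN but does not create it.
        System.out.println(sigmoid(Double.NaN));  // NaN
    }
}
```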
I suggest you output some debug info like this, then search the output for NaN. This is easier if you redirect the output to a file: if you're running on the command line you could run java MyApp > myapp_log.txt, then open myapp_log.txt in a text editor and search it.
Or to make the output easier to handle, you could make your debug logic only output when it finds NaN, such as...
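public double sigmoid(double x) {
double sigmoid = 1 / (1 + Math.exp(-x));
// Use Double.isNaN: NaN never compares equal to anything, so sigmoid == Double.NaN is always false
if (Double.isNaN(sigmoid))
System.out.println("1 / (1 + Math.exp(" + (-x) + ")) = " + sigmoid);
return sigmoid;
}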
Just remember to do it for everything you calculate, including your dsigmoid, your add, etc., everywhere you have any type of calculation. If you put enough of it everywhere then you will catch the problem and see lines output like "1 / (1 + Math.exp(NaN)) = NaN".
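The same NaN-only check can be applied to a whole matrix at once. Here is a minimal sketch of such a guard, assuming the matrix data is a double[][] as in the question. The class NanGuard and its check method are hypothetical helpers, not part of the posted code.

```java
// Sketch: a hypothetical guard that fails fast on the first non-finite matrix entry.
final class NanGuard {

    // Returns data unchanged, or throws at the first NaN or Infinity it finds.
    static double[][] check(String label, double[][] data) {
        for (int i = 0; i < data.length; i++) {
            for (int j = 0; j < data[i].length; j++) {
                double v = data[i][j];
                // Double.isNaN is required here: v == Double.NaN is false for every v.
                if (Double.isNaN(v) || Double.isInfinite(v)) {
                    throw new IllegalStateException(
                            label + "[" + i + "][" + j + "] = " + v);
                }
            }
        }
        return data;
    }
}
```

Calling something like NanGuard.check("hidden after sigmoid", hidden.data) after each step of train would make the stack trace point at the first step that goes wrong. Checking for Infinity as well tends to catch an overflow one step before it turns into NaN.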
There are various things you can do with a debugger. You can run your program but step through it one line at a time and examine every variable and result as it is happening. Depending on the size of your matrices and how many times these functions get called, this may take a lot of effort.
Or you might be able to set a "watch" on a variable so that the program halts when its value becomes NaN, then inspect the state of the program at that moment. I'm not sure whether any Java debuggers have this functionality, as I've only done this type of debugging in C or assembly, so you'd have to check whether you have access to such a debugger.
Upvotes: 1