aletherios

Reputation: 1

Octave implementation of simple neural network with one float output

Some of you may be familiar with the simple handwritten-digit classification NN that's part of Andrew Ng's ML course on Coursera. To improve my understanding of the theory, I'm trying to modify the implementation so that it outputs one float instead of 10 classification labels.

It has only one hidden layer. The code below is my attempt, but the backprop produces wrong gradients: they don't match the numerical gradient check at all. Because the network outputs a single number, the activation of the output node is just the identity function f(x) = x. Because of this, I assumed the error in the output unit is simply distributed backwards to the hidden layer in proportion to the weights. Perhaps someone can spot what I'm missing and why the gradients are wrong.

m is the number of training examples, y is a vector of target outputs, and X holds the training data, one example per row.

X = [ones(m, 1) X];      %insert bias column
a2 = tanh(X * Theta1');  %hidden layer activation
a2 = [ones(m, 1) a2];
a3 = a2 * Theta2';       %linear combination for final output
Cost = sum((y-a3).^2)/m; %mean squared error, averaged over all training data

Backprop, iterating over the m examples to accumulate gradients before averaging:

for i = 1:m
    a1 = [1 X(i, :)]';  %get current training row
    z2 = Theta1 * a1;
    a2 = tanh(z2);
    a2 = [1; a2]; 
    z3 = Theta2 * a2;
    a3 = z3; %no activation function for final output
    
    delta3 = (a3-y(i))^2; %cost of current training row
    delta2 = Theta2' * delta3; %proportionally distribute error backwards, because no activation function was used?
    delta2 = delta2(2:end); %cut out bias element
    
    Theta2_grad = Theta2_grad + delta3 * a2'; %accumulate gradients
    Theta1_grad = Theta1_grad + delta2 * a1';

endfor

Theta1_grad = (1/m) * Theta1_grad; %average gradients
Theta2_grad = (1/m) * Theta2_grad;

Upvotes: 0

Views: 292

Answers (1)

Aleksandar Petrovic

Reputation: 11

I think your delta3 is wrong: it should be delta3 = a3 - y(i), not (a3 - y(i))^2. The delta of the output unit is the derivative of the cost with respect to z3, not the cost itself.
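There is a second issue as well: delta2 must be multiplied element-wise by the tanh derivative, 1 - tanh(z2)^2, since the hidden layer is not linear; and with your cost sum((y-a3).^2)/m the exact delta carries a factor of 2 (often absorbed by defining the cost with 1/(2m)). Here is a minimal NumPy sketch of the corrected backprop, verified against a numerical gradient check; the sizes and random weights are hypothetical, chosen just to run the check:

```python
import numpy as np

# Toy network: one tanh hidden layer, linear (identity) output unit.
# Names mirror the Octave code in the question.
rng = np.random.default_rng(0)
m, n_in, n_hid = 5, 3, 4
X = rng.normal(size=(m, n_in))
y = rng.normal(size=m)
Theta1 = rng.normal(size=(n_hid, n_in + 1)) * 0.1  # hidden weights (incl. bias)
Theta2 = rng.normal(size=(1, n_hid + 1)) * 0.1     # output weights (incl. bias)

def cost(T1, T2):
    a1 = np.hstack([np.ones((m, 1)), X])               # add bias column
    a2 = np.hstack([np.ones((m, 1)), np.tanh(a1 @ T1.T)])
    a3 = (a2 @ T2.T).ravel()                           # linear output
    return np.sum((y - a3) ** 2) / m                   # same cost as the question

# Forward pass
a1 = np.hstack([np.ones((m, 1)), X])
z2 = a1 @ Theta1.T
a2 = np.hstack([np.ones((m, 1)), np.tanh(z2)])
a3 = (a2 @ Theta2.T).ravel()

# Backward pass: delta3 is dCost/dz3, NOT the squared error itself
delta3 = 2 * (a3 - y) / m
# Hidden delta: distribute backwards AND multiply by tanh'(z2)
delta2 = (delta3[:, None] @ Theta2)[:, 1:] * (1 - np.tanh(z2) ** 2)

Theta2_grad = delta3[None, :] @ a2   # shape (1, n_hid + 1)
Theta1_grad = delta2.T @ a1          # shape (n_hid, n_in + 1)

# Numerical gradient check on a few entries of Theta1
eps = 1e-6
for (r, c) in [(0, 0), (1, 2), (3, 1)]:
    T1p = Theta1.copy(); T1p[r, c] += eps
    T1m = Theta1.copy(); T1m[r, c] -= eps
    num = (cost(T1p, Theta2) - cost(T1m, Theta2)) / (2 * eps)
    assert abs(num - Theta1_grad[r, c]) < 1e-6
```

If you drop either the factor of 2 or the (1 - tanh(z2)^2) term, the check above fails, which matches the mismatch you are seeing against your numerical gradients.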

Upvotes: 0
