Reputation: 3944
I'm running gradient checking in Python as follows:
import numpy as np
import scipy.optimize

class NeuralNetwork:
    def gradient_checking(self):
        # debuginitializeweights and vectorize are helper functions defined elsewhere in the module
        m = 5
        theta1 = debuginitializeweights(self.size[1], self.size[0])
        theta2 = debuginitializeweights(self.size[2], self.size[1])
        thetas = vectorize([theta1, theta2], self.size)
        X = debuginitializeweights(m, self.size[0] - 1)
        y = 1 + np.mod(np.array(range(m)), 3)
        return scipy.optimize.check_grad(self.cost, self.grad, thetas, [X, y])
where the class methods have these signatures:
def cost(self, thetas, X, label):
def grad(self, thetas, X, label):
However, running the gradient check fails with:
  File "/home/andrey/Data/Hacking/ML/Coursera_Andrew/neuralnetwork.py", line 142, in gradient_checking
    return check_grad(self.cost, self.grad, thetas, [X, y])
  File "/usr/local/lib/python2.7/dist-packages/scipy/optimize/optimize.py", line 656, in check_grad
    return sqrt(sum((grad(x0, *args) -
TypeError: grad() takes exactly 4 arguments (3 given)
How can I fix this error?
Upvotes: 0
Views: 85
Reputation: 3944
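It turns out this is simply a problem with how the optional extra arguments are passed. check_grad(func, grad, x0, *args) forwards everything after x0 to both functions, so those arguments must be passed individually rather than bundled into one list or tuple. Here is a quick example: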
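import numpy as np
from scipy.optimize import check_grad

class NeuralNetwork:
    def cost(self, x, A, b):
        return x[0] ** 2 - 0.5 * x[1] ** 3 + A * b

    def grad(self, x, A, b):
        # A * b does not depend on x, so it contributes nothing to the gradient
        return np.array([2 * x[0], -1.5 * x[1] ** 2])

a = NeuralNetwork()
print(a.cost([1.5, -1.5], 10, 1))
# Extra arguments go after x0, one by one
print(check_grad(a.cost, a.grad, np.array([1.5, -1.5]), 10, 1))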
Previously I did:
check_grad(a.cost, a.grad, [1.5, -1.5], (10, 1))
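That's why arguments kept going missing: check_grad forwarded the tuple (10, 1) as a single argument, so grad received only x and that tuple instead of x, A and b.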
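If the extra arguments are already stored in a tuple, there are two common ways around this. The sketch below reuses the toy NeuralNetwork from this answer (with the gradient computed only with respect to x) and assumes Python 3 with NumPy and SciPy installed. Option 1 unpacks the tuple at the call site with `*`; option 2 binds the extras with `functools.partial` so check_grad only ever sees functions of x (a lambda would work just as well).

```python
import functools

import numpy as np
from scipy.optimize import check_grad


class NeuralNetwork:
    def cost(self, x, A, b):
        return x[0] ** 2 - 0.5 * x[1] ** 3 + A * b

    def grad(self, x, A, b):
        # A * b does not depend on x, so it drops out of the gradient
        return np.array([2 * x[0], -1.5 * x[1] ** 2])


a = NeuralNetwork()
x0 = np.array([1.5, -1.5])
extra = (10, 1)  # extra arguments that happen to be stored in a tuple
A, b = extra

# Option 1: unpack the tuple at the call site, so check_grad receives 10 and 1 separately
err_unpacked = check_grad(a.cost, a.grad, x0, *extra)

# Option 2: bind the extra arguments up front; check_grad then calls cost_x(x) and grad_x(x)
cost_x = functools.partial(a.cost, A=A, b=b)
grad_x = functools.partial(a.grad, A=A, b=b)
err_bound = check_grad(cost_x, grad_x, x0)

print(err_unpacked, err_bound)
```

Both calls should report a small error (on the order of the finite-difference step), confirming the analytical gradient. Binding with partial is handy when the same cost/grad pair is passed to several SciPy routines.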
Upvotes: 0
Reputation: 9633
The stack trace tells you exactly what you need to know:
TypeError: grad() takes exactly 4 arguments (3 given)
Your grad signature reflects the 4-argument requirement:
def grad(self, thetas, X, label):
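To see why the count comes out one short, the sketch below compares the signature of the plain function with that of the bound method and mimics the grad(x0, *args) call from the traceback. `call_like_check_grad` is a hypothetical helper, not part of SciPy, and the grad body is a placeholder. Note that Python 3 words the error as "missing 1 required positional argument" rather than Python 2's "takes exactly 4 arguments (3 given)".

```python
import inspect


class NeuralNetwork:
    # Same signature as the question's grad; the body is a placeholder
    def grad(self, thetas, X, label):
        return thetas


def call_like_check_grad(grad, x0, *args):
    # Hypothetical helper mirroring the grad(x0, *args) call inside check_grad
    return grad(x0, *args)


net = NeuralNetwork()
print(inspect.signature(NeuralNetwork.grad))  # (self, thetas, X, label)
print(inspect.signature(net.grad))            # (thetas, X, label): self is already bound

try:
    # X and label bundled into one list, as in check_grad(..., thetas, [X, y])
    call_like_check_grad(net.grad, [0.0], [[1.0], [2.0]])
except TypeError as exc:
    bundled_error = str(exc)
    print(bundled_error)

# X and label passed separately
result = call_like_check_grad(net.grad, [0.0], [1.0], [2.0])
```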
The failing line inside check_grad uses argument unpacking in its call to grad():
return sqrt(sum((grad(x0, *args))))
Because grad() is called as a bound method, the implicit self takes the self position in the argument list and x0 takes the thetas position, leaving X and label to be filled by *args. Try printing args, or examining it with PDB, to confirm whether it contains two items. Since check_grad collects the extra arguments through its own *args parameter, args is always a tuple, so its type is not the problem. More likely it does not contain both items you expected: given your call check_grad(self.cost, self.grad, thetas, [X, y]), it probably holds a single item, the list [X, y].
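Printing args can be done without touching SciPy by wrapping cost and grad. The sketch below assumes Python 3 with NumPy and SciPy; the least-squares cost is a hypothetical stand-in with the question's signatures, and `logged` is a hypothetical debugging helper. It records how many extra arguments each call receives, first with X and y bundled in a list, then passed separately.

```python
import numpy as np
from scipy.optimize import check_grad


class NeuralNetwork:
    # Hypothetical least-squares cost/grad using the question's signatures
    def cost(self, thetas, X, label):
        r = X @ thetas - label
        return 0.5 * float(r @ r)

    def grad(self, thetas, X, label):
        return X.T @ (X @ thetas - label)


def logged(f, log):
    """Wrap f so each call prints and records how many extra arguments arrived in *args."""
    def wrapper(x, *args):
        print("args has", len(args), "item(s):", [type(a).__name__ for a in args])
        log.append(len(args))
        return f(x, *args)
    return wrapper


rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
y = rng.standard_normal(5)
thetas = rng.standard_normal(3)
net = NeuralNetwork()

# Bundled: args arrives as a 1-tuple holding the list [X, y], so the first call fails
bundled_log = []
try:
    check_grad(logged(net.cost, bundled_log), logged(net.grad, bundled_log), thetas, [X, y])
except TypeError as exc:
    print("TypeError:", exc)

# Separate: args arrives as (X, y), matching grad(self, thetas, X, label)
unpacked_log = []
err = check_grad(logged(net.cost, unpacked_log), logged(net.grad, unpacked_log), thetas, X, y)
unpacked_counts = set(unpacked_log)
print("gradient error:", err)
```

In the bundled case the wrapper reports a single list argument before the TypeError; in the separate case every call receives two arrays and check_grad returns a small error.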
Upvotes: 1