What does the "updates" argument do when called this way?
```python
f_grad_shared = theano.function([x, mask, y], cost, updates=zgup + rg2up,
                                name='adadelta_f_grad_shared')
```
All the documentation I have seen about the "updates" argument in Theano functions talks about pairs of the form (shared variable, expression used to update the shared variable). However, here there is only an expression, so how do I know which shared variable is updated?
I guess the shared variable is somehow implicit, but `zgup` and `rg2up` both depend on different shared variables:
```python
zipped_grads = [theano.shared(p.get_value() * numpy_floatX(0.),
                              name='%s_grad' % k)
                for k, p in tparams.iteritems()]
running_grads2 = [theano.shared(p.get_value() * numpy_floatX(0.),
                                name='%s_rgrad2' % k)
                  for k, p in tparams.iteritems()]
zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)]
rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2))
         for rg2, g in zip(running_grads2, grads)]
```
This code comes from `lstm.py` in http://deeplearning.net/tutorial/lstm.html
Thanks
It is correct to think that `updates` should be a list (or dictionary) of key-value pairs, where the key is a shared variable and the value is a symbolic expression describing how to update the corresponding shared variable.
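To make that contract concrete without needing Theano installed, here is a minimal pure-Python sketch. `Shared` and `make_function` are hypothetical stand-ins, not Theano's API or implementation; the sketch only assumes the behaviour Theano gives `updates`, namely that every update expression is evaluated from the values held before the call, and only then are the shared variables named by the keys overwritten. Both the list-of-pairs form and the dict form are accepted.

```python
class Shared:
    """Hypothetical stand-in for a Theano shared variable: a named, mutable value."""

    def __init__(self, value, name):
        self.value = value
        self.name = name


def make_function(output, updates):
    """Hypothetical helper that mimics how an `updates` argument is applied.

    `output` and each update expression are zero-argument callables that read
    the current values of shared variables. `updates` may be a list of
    (shared, expression) pairs or a dict mapping shared -> expression.
    """
    pairs = list(updates.items()) if isinstance(updates, dict) else list(updates)

    def call():
        # Evaluate the output and every update expression with the values
        # held *before* this call...
        result = output()
        new_values = [(shared, expr()) for shared, expr in pairs]
        # ...and only then overwrite the shared variables named by the keys.
        for shared, value in new_values:
            shared.value = value
        return result

    return call


count = Shared(0, "count")
total = Shared(0, "total")

# The first element of each pair says WHICH shared variable changes,
# the second says HOW its new value is computed.
step = make_function(
    output=lambda: total.value,
    updates=[(count, lambda: count.value + 1),
             (total, lambda: total.value + count.value)],
)

# Dict form: because all new values are computed before any is written,
# these two updates swap the values instead of copying one into the other.
a = Shared(1, "a")
b = Shared(2, "b")
swap = make_function(output=lambda: None,
                     updates={a: lambda: b.value, b: lambda: a.value})
```

Calling `step()` three times returns 0, 0 and 1 and leaves both `count.value` and `total.value` at 3; calling `swap()` once leaves `a.value == 2` and `b.value == 1`.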
In the tutorial code, these two lines create the pairs:
```python
zgup = [(zg, g) for zg, g in zip(zipped_grads, grads)]
rg2up = [(rg2, 0.95 * rg2 + 0.05 * (g ** 2))
         for rg2, g in zip(running_grads2, grads)]
```
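Read as assignments, each pair says "overwrite the left-hand shared variable with the right-hand expression". Writing $z_k$ for the entry of `zipped_grads` belonging to parameter $k$, $r_k$ for the matching entry of `running_grads2`, and $g_k$ for its gradient (notation introduced here only for illustration), one call of `f_grad_shared` performs:

```latex
\begin{aligned}
z_k &\leftarrow g_k \\
r_k &\leftarrow 0.95\, r_k + 0.05\, g_k^{2}
\end{aligned}
```

The second line is an exponentially decaying average of the squared gradient, which AdaDelta uses later to scale the step.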
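`zipped_grads` and `running_grads2`, which were created in the previous lines, are each just a list of shared variables. Here, each of those shared variables is paired with its update expression using the Python `zip` function, which (in Python 2) returns a list of pairs. `zgup + rg2up` then concatenates the two lists, so `updates` receives a single list of (shared variable, expression) pairs. In fact, in Python 2 the first of these lines could be replaced with
```python
zgup = zip(zipped_grads, grads)
```
(In Python 3, `zip` returns an iterator rather than a list, so you would need `list(zip(zipped_grads, grads))` for `zgup + rg2up` to work.)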
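The pairing and concatenation can be seen without Theano by letting plain strings stand in for the shared variables and gradient expressions. The names below (`"W_grad"`, `"dcost/dW"` and so on) are hypothetical placeholders; in `lstm.py` the entries are `theano.shared` objects and symbolic tensors.

```python
# Plain strings stand in for Theano shared variables and symbolic gradients
# (hypothetical placeholders, for illustration only).
zipped_grads = ["W_grad", "b_grad"]
running_grads2 = ["W_rgrad2", "b_rgrad2"]
grads = ["dcost/dW", "dcost/db"]

# zip pairs element i of the first list with element i of the second;
# list(...) is needed in Python 3, where zip returns an iterator.
zgup = list(zip(zipped_grads, grads))
rg2up = [(rg2, "0.95 * %s + 0.05 * (%s ** 2)" % (rg2, g))
         for rg2, g in zip(running_grads2, grads)]

# `+` concatenates the two lists, so `updates` is one flat list of
# (shared variable, update expression) pairs.
updates = zgup + rg2up

for target, expression in updates:
    print(target, "<-", expression)
```

Every element of `updates` still names its own target, so nothing about the shared variable is implicit: `zgup` targets the `zipped_grads` entries and `rg2up` targets the `running_grads2` entries.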
This code is quite complex because it implements the AdaDelta update mechanism. If you want to see how `updates` works in a simpler setting, take a look at the basic stochastic gradient descent update in the Theano MLP tutorial:
```python
updates = [
    (param, param - learning_rate * gparam)
    for param, gparam in zip(classifier.params, gparams)
]
```
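As a rough NumPy analogue of that SGD update, the sketch below builds the same kind of `(param, new_value)` list on each iteration and then applies it. The toy least-squares problem, the `params` dict and the `grads` helper are hypothetical and only play the roles of `classifier.params` and `gparams`; Theano would build the expressions once symbolically instead of recomputing a Python list every step.

```python
import numpy as np

# Hypothetical toy problem: recover y = 3 * x + 0.5 by least squares.
rng = np.random.default_rng(0)
x = rng.normal(size=100)
y = 3.0 * x + 0.5

params = {"w": 0.0, "b": 0.0}
learning_rate = 0.1


def grads(p):
    """Gradients of the mean squared error with respect to w and b."""
    err = p["w"] * x + p["b"] - y
    return {"w": np.mean(2 * err * x), "b": np.mean(2 * err)}


for _ in range(200):
    g = grads(params)
    # Same shape as the tutorial's list: (param, param - learning_rate * gparam).
    updates = [(name, params[name] - learning_rate * g[name]) for name in params]
    # Every new value was computed from the old ones; now write them back.
    for name, new_value in updates:
        params[name] = new_value
```

After the loop, `params["w"]` and `params["b"]` are very close to 3.0 and 0.5, since the data are noise-free.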