Reputation: 1130
mean_sqr = tf.reduce_mean(tf.pow(y_ - y, 2))
optimizer = tf.train.AdamOptimizer(LEARNING_RATE)
gradients, variables = zip(*optimizer.compute_gradients(mean_sqr))
opt = optimizer.apply_gradients(list(zip(gradients, variables)))
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for j in range(TRAINING_EPOCHS):
    sess.run(opt, feed_dict={x: batch_xs, y_: batch_xs})
I don't clearly understand what compute_gradients returns. Does it return sum(dy/dx) over the x values in batch_xs, with apply_gradients then performing the update as:
theta <- theta - LEARNING_RATE * 1/m * gradients?
Or does it already return the average of the gradients over the batch, i.e. sum(dy/dx) * 1/m, where m is the batch size?
Upvotes: 8
Views: 4799
Reputation: 66805
compute_gradients(a, b) returns d[ sum a ] / db. In your case this is d mean_sqr / d theta, where theta is the set of all variables. There is no "dx" in this equation: you are not computing gradients with respect to the inputs. So what happens to the batch dimension? You remove it yourself in the definition of mean_sqr:
mean_sqr = tf.reduce_mean(tf.pow(y_ - y, 2))
thus (assuming y is 1-D for simplicity)
d[ mean_sqr ] / d theta = d[ 1/M SUM_i=1^M (pred(x_i) - y_i)^2 ] / d theta
= 1/M SUM_i=1^M d[ (pred(x_i) - y_i)^2 ] / d theta
So you control whether the gradient is summed over the batch, averaged, or reduced in some other way. If you defined mean_sqr with reduce_sum instead of reduce_mean, the gradients would be the sum over the batch, and so on.
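To make the reduce_mean versus reduce_sum point concrete, here is a minimal sketch in plain Python rather than TensorFlow, so it runs without any dependency. It assumes a hypothetical 1-D linear model pred(x) = theta * x; the function names and the data are made up for illustration and are not part of any TensorFlow API.

```python
# Hypothetical 1-D model pred(x) = theta * x; all names here are illustrative.

def per_example_grad(theta, x, y):
    """d[(theta*x - y)^2] / d theta for a single example."""
    return 2.0 * (theta * x - y) * x


def grad_of_mean_loss(theta, xs, ys):
    """Gradient of 1/M * SUM_i (theta*x_i - y_i)^2 (the reduce_mean case)."""
    m = len(xs)
    return sum(per_example_grad(theta, x, y) for x, y in zip(xs, ys)) / m


def grad_of_sum_loss(theta, xs, ys):
    """Gradient of SUM_i (theta*x_i - y_i)^2 (the reduce_sum case)."""
    return sum(per_example_grad(theta, x, y) for x, y in zip(xs, ys))


def numeric_grad(loss, theta, eps=1e-6):
    """Central finite difference, used only as a sanity check."""
    return (loss(theta + eps) - loss(theta - eps)) / (2.0 * eps)


def mean_loss(theta):
    """Mean squared error over the toy batch below."""
    return sum((theta * x - y) ** 2 for x, y in zip(xs, ys)) / len(xs)


# Toy batch of M = 4 examples (made-up data).
xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.0, 6.0, 8.0]
theta = 0.5

g_mean = grad_of_mean_loss(theta, xs, ys)
g_sum = grad_of_sum_loss(theta, xs, ys)

print("gradient of mean loss:", g_mean)
print("gradient of sum loss: ", g_sum)
print("finite-difference check:", numeric_grad(mean_loss, theta))
```

The gradient of the summed loss is M times the gradient of the mean loss, which is why switching the reduction changes the effective step size for a fixed learning rate.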
apply_gradients, on the other hand, simply applies the gradients. The exact update rule depends on the optimiser; for GradientDescentOptimizer it would be
theta <- theta - learning_rate * gradients(theta)
For Adam, which you are using, the update rule is of course more complex.
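As a rough illustration of how the two update rules differ, the sketch below writes both for a single scalar parameter in plain Python. The Adam step follows the textbook formulation with the usual default hyperparameters; TensorFlow's own implementation may differ in details such as how epsilon is applied, so treat this as an assumption-laden sketch and not as the library's code. The function names and the gradient value are hypothetical.

```python
import math


def sgd_step(theta, grad, learning_rate):
    """Plain gradient descent: theta <- theta - learning_rate * grad."""
    return theta - learning_rate * grad


def adam_step(theta, grad, m, v, t,
              learning_rate=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam step for a scalar parameter (t starts at 1)."""
    m = beta1 * m + (1.0 - beta1) * grad          # running mean of gradients
    v = beta2 * v + (1.0 - beta2) * grad * grad   # running mean of squared gradients
    m_hat = m / (1.0 - beta1 ** t)                # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    theta = theta - learning_rate * m_hat / (math.sqrt(v_hat) + eps)
    return theta, m, v


# Illustrative values only: one parameter and one gradient.
theta0, grad = 0.5, -22.5

theta_sgd = sgd_step(theta0, grad, learning_rate=0.1)
theta_adam, m, v = adam_step(theta0, grad, m=0.0, v=0.0, t=1)

print("after SGD step: ", theta_sgd)
print("after Adam step:", theta_adam)
```

On the first step Adam moves by roughly learning_rate regardless of the gradient's magnitude, whereas the gradient descent step scales directly with it.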
Note, however, that tf.gradients is closer to backprop than to a true gradient in the mathematical sense: it depends on the dependencies in the graph and does not recognise dependencies that run in the "opposite" direction.
Upvotes: 6