Sri Ramana

Reputation: 121

Why does Tensorflow Reshape tf.reshape() break the flow of gradients?

I create a tf.Variable(), define a simple function of that variable, flatten the original variable with tf.reshape(), and then take tf.gradients() between the function and the flattened variable. Why does that return [None]?

import numpy as np
import tensorflow as tf

var = tf.Variable(np.ones((5, 5)), dtype=tf.float32)
f = tf.reduce_sum(tf.reduce_sum(tf.square(var)))  # scalar: sum of squared entries
var_f = tf.reshape(var, [-1])                     # flattened version of var
print(tf.gradients(f, var_f))                     # prints [None]

The code block above returns [None] when executed. Is this a bug? Please help!

Upvotes: 8

Views: 4235

Answers (1)

Vijay Mariappan

Reputation: 17201

You are taking the derivative of f with respect to var_f, but f is not a function of var_f; it is a function of var. That's why you are getting [None]. If you change the code to:

import numpy as np
import tensorflow as tf

var = tf.Variable(np.ones((5, 5)), dtype=tf.float32)
var_f = tf.reshape(var, [-1])                        # flatten first
f = tf.reduce_sum(tf.reduce_sum(tf.square(var_f)))   # build f from the flattened tensor
grad = tf.gradients(f, var_f)                        # gradient is now defined
print(grad)

your gradients will be defined:

[<tf.Tensor 'gradients_28/Square_32_grad/mul_1:0' shape=(25,) dtype=float32>]
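If you prefer to keep f defined on the original var, as in the question, but still want the gradient in flattened form, a minimal sketch (my own variation, assuming the question's var and f) is to differentiate with respect to var and flatten the result afterwards; reshaping changes only the shape, not the values:

# Sketch, assuming the question's original graph:
#   var = tf.Variable(np.ones((5, 5)), dtype=tf.float32)
#   f   = tf.reduce_sum(tf.square(var))
grad_var = tf.gradients(f, var)[0]        # defined: shape (5, 5), equals 2 * var
grad_flat = tf.reshape(grad_var, [-1])    # flattened gradient, shape (25,)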

The visualization of the graphs for the following code is given below:

var = tf.Variable(np.ones((5, 5)), dtype=tf.float32, name='var')
f = tf.reduce_sum(tf.reduce_sum(tf.square(var)), name='f')
var_f = tf.reshape(var, [-1], name='var_f')
grad_1 = tf.gradients(f, var_f, name='grad_1')   # [None]: no path from f back to var_f
grad_2 = tf.gradients(f, var, name='grad_2')     # defined: f depends on var

[TensorBoard visualization of the graph and the back-propagation (gradient) subgraphs for grad_1 and grad_2]

grad_1 is not defined (it is [None]) because there is no path in the graph from f back to var_f, while grad_2 is defined. The back-propagation (gradient) graphs of the two cases are shown in the visualization above.
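As a quick numerical check (my own sketch, not part of the original answer, assuming the graph-building code above), running the graph in a session shows grad_2 evaluating to 2 * var while grad_1 stays [None]:

# Sketch: verify the two gradients at runtime
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(grad_1)               # [None] -- cannot be evaluated
    print(sess.run(grad_2))     # [array of 2.0s with shape (5, 5)] == 2 * var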

Upvotes: 6
