Yash Kant

Reputation: 498

How to update a sub-tensor inside a tensor in tensorflow?

I'm working with MNIST and have a tensor of gradients with shape [?, 28, 28, 1]. I want to zero out a few of the [28, 28, 1] sub-tensors inside it. How should I accomplish this?

I know the indices (as a list) of the sub-tensors I need to zero out. I tried the code below, but tf.scatter_update can only modify variables, not tensors. I also tried stacking up the required sub-tensors of zeros and ones, but couldn't build the result I needed.

    dy_dx, = tf.gradients(loss, x_adv)
    zeroes = tf.zeros(dy_dx[0].get_shape(), tf.float32)
    dy_dx = tf.scatter_update(dy_dx, indices, zeroes)
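Two details about the attempt above are worth noting. tf.scatter_update needs a tf.Variable as its first argument, and its updates must have shape indices.shape + ref.shape[1:], i.e. [k, 28, 28, 1] for k indices rather than a single [28, 28, 1] block. Newer TensorFlow releases (it is available in TF 2.x; check your version if you are on 1.x) provide tf.tensor_scatter_nd_update, which does the same kind of scatter on an ordinary tensor and returns a new tensor. A minimal sketch, assuming eager TF 2.x; `zero_out_with_scatter` is a hypothetical helper name:

```python
import tensorflow as tf


def zero_out_with_scatter(grads, indices):
    """Return a copy of grads whose examples at `indices` are all zeros.

    Hypothetical helper; assumes grads has shape [batch, 28, 28, 1].
    """
    # tensor_scatter_nd_update indexes the first axis with shape [k, 1]
    idx = tf.reshape(tf.convert_to_tensor(indices, dtype=tf.int32), [-1, 1])
    # One zero block per index: shape [k] + grads.shape[1:], e.g. [k, 28, 28, 1]
    updates_shape = tf.concat([tf.shape(idx)[:1], tf.shape(grads)[1:]], axis=0)
    updates = tf.zeros(updates_shape, dtype=grads.dtype)
    # Works on a plain tensor (no tf.Variable); grads itself is not modified
    return tf.tensor_scatter_nd_update(grads, idx, updates)
```

For example, `zero_out_with_scatter(tf.ones([4, 28, 28, 1]), [0, 2])` zeros the first and third images and leaves the other two unchanged.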

Thanks!

Upvotes: 2

Views: 455

Answers (1)

Stephen

Reputation: 824

I'd suggest creating a TensorFlow constant with zeros at the locations you want to zero out and ones everywhere else. Then you could create an op that uses tf.multiply to do elementwise multiplication of the constant and dy_dx. Depending on the structure of your graph, you might need to feed the result to dy_dx in your next call to session.run; you can replace any Tensor with feed data, including variables and constants.
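A minimal sketch of this masking idea, assuming TF 2.x eager execution. Because the batch dimension is `?` here, it builds the mask at run time from tf.shape instead of as a fixed constant; the helper name `zero_out_examples` is hypothetical.

```python
import tensorflow as tf


def zero_out_examples(grads, indices):
    """Return grads with the whole [28, 28, 1] examples at `indices` set to zero.

    grads:   float tensor of shape [batch, 28, 28, 1] (batch may be unknown)
    indices: list or 1-D int tensor of batch positions to zero out
    """
    indices = tf.convert_to_tensor(indices, dtype=tf.int32)
    batch_size = tf.shape(grads)[0]
    # One row per index, one column per example; summing marks the hit examples
    hits = tf.reduce_sum(
        tf.one_hot(indices, depth=batch_size, dtype=grads.dtype), axis=0)
    # 0.0 where an example should be zeroed, 1.0 elsewhere (minimum guards repeats)
    keep = 1.0 - tf.minimum(hits, 1.0)
    # Shape [batch, 1, 1, 1] so tf.multiply broadcasts over each 28x28x1 image
    mask = tf.reshape(keep, [-1, 1, 1, 1])
    return tf.multiply(grads, mask)
```

Reshaping the mask to [batch, 1, 1, 1] keeps it small and lets broadcasting do the work, rather than materializing a full [batch, 28, 28, 1] tensor of ones and zeros. The ops used here also exist in TF 1.x, so the same function should apply to the output of tf.gradients in a graph, though that has not been checked against a specific 1.x release.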

Incidentally, if you just want to apply dropout to the input layer, you can use tf.layers.dropout.
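Note that tf.layers.dropout (a TF 1.x API) defaults to `training=False`, in which case it returns its input unchanged, so `training=True` must be passed for it to drop anything. Dropout also zeros random positions rather than the specific indices in the question. If dropping whole images at random is the goal, a `noise_shape` of [batch, 1, 1, 1] makes one keep/drop decision per example. A minimal sketch using tf.nn.dropout, assuming TF 2.x eager execution; `drop_whole_examples` is a hypothetical helper name:

```python
import tensorflow as tf


def drop_whole_examples(x, rate):
    """Randomly zero entire [28, 28, 1] images in a [batch, 28, 28, 1] tensor."""
    # noise_shape [batch, 1, 1, 1]: one random keep/drop decision per example,
    # broadcast across that example's 28x28x1 image
    noise_shape = tf.concat([tf.shape(x)[:1], [1, 1, 1]], axis=0)
    # Kept examples are scaled by 1 / (1 - rate), as is standard for dropout
    return tf.nn.dropout(x, rate=rate, noise_shape=noise_shape)
```

With `rate=0.5` and an input of ones, each image comes back either entirely 0.0 or entirely 2.0.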

Upvotes: 1
