Reputation: 91
I have two 2-D tensors of the same shape (say, [80, 90]). Elementwise, I want to take the value from whichever tensor has the larger absolute value.
In Python with NumPy I would do something like this:
import numpy as np
mask = np.abs(a) > np.abs(b)
c = a * mask + b * ~mask
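A self-contained NumPy sketch of the same selection, using hypothetical random inputs of the shape from the question. It also shows that `np.where(mask, a, b)` gives the same result as the arithmetic blend, which makes the intent ("pick from a, else from b") explicit.

```python
import numpy as np

# Hypothetical example inputs with the shape mentioned in the question.
rng = np.random.default_rng(0)
a = rng.standard_normal((80, 90))
b = rng.standard_normal((80, 90))

mask = np.abs(a) > np.abs(b)       # True where |a| is strictly larger
c_mask = a * mask + b * ~mask      # arithmetic blend from the question
c_where = np.where(mask, a, b)     # equivalent elementwise selection

print(c_where.shape)
print(np.array_equal(c_mask, c_where))
```

One design note: `np.where` selects values instead of multiplying by 0/1, so it does not turn an `inf` in the unselected tensor into `nan` (since `inf * 0` is `nan`), which the multiplicative blend would do.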
How do I do this in TensorFlow while still being able to compute gradients?
I know I can do this:
mask = tf.abs(a) > tf.abs(b)
but then:
c = tf.cast(mask, tf.float32) * a + tf.cast(~mask, tf.float32) * b
doesn't pass gradients, because the cast operation doesn't propagate them.
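One way to check where gradients actually flow is `tf.GradientTape`. The sketch below assumes TensorFlow 2.x in eager mode and uses small hypothetical inputs. The boolean mask itself gets no gradient (it comes from a non-differentiable comparison), but `a` and `b` do receive gradients through the multiplications: the gradient with respect to `a` equals the cast mask, and the gradient with respect to `b` equals its complement.

```python
import tensorflow as tf

# Hypothetical small inputs; tf.Variable objects are watched by the tape automatically.
a = tf.Variable([[1.0, -5.0], [3.0, 0.5]])
b = tf.Variable([[-2.0, 4.0], [1.0, -0.25]])

with tf.GradientTape() as tape:
    mask = tf.abs(a) > tf.abs(b)  # boolean, not differentiable
    c = tf.cast(mask, tf.float32) * a + tf.cast(~mask, tf.float32) * b
    loss = tf.reduce_sum(c)       # scalar so tape.gradient returns d(loss)/d(input)

grad_a, grad_b = tape.gradient(loss, [a, b])
print(grad_a.numpy())  # 1.0 where |a| > |b|, 0.0 elsewhere
print(grad_b.numpy())  # the complement of grad_a
```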
Upvotes: 0
Views: 575
Reputation: 2004
@coldspped's answer works fine. But if you want a general-purpose method that selects elements according to an arbitrary boolean mask, you can use the tf.where API. For your problem, the answer is as follows:
mask = tf.abs(a) > tf.abs(b)  # tf.where needs a boolean condition, not a difference
c = tf.where(mask, a, b)      # a where |a| > |b|, otherwise b
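A minimal end-to-end sketch of the tf.where approach, assuming TensorFlow 2.x in eager mode; the input values are hypothetical. It computes `c` under a `tf.GradientTape` and shows that each input receives a gradient of 1 exactly at the positions where it was selected and 0 elsewhere.

```python
import tensorflow as tf

# Hypothetical small inputs of the same shape.
a = tf.Variable([[1.0, -5.0], [3.0, 0.5]])
b = tf.Variable([[-2.0, 4.0], [1.0, -0.25]])

with tf.GradientTape() as tape:
    mask = tf.abs(a) > tf.abs(b)  # boolean condition, elementwise
    c = tf.where(mask, a, b)      # take a where mask is True, else b
    loss = tf.reduce_sum(c)

grad_a, grad_b = tape.gradient(loss, [a, b])
print(c.numpy())       # [[-2. -5.] [ 3.  0.5]]
print(grad_a.numpy())  # gradient flows to a only where a was picked
print(grad_b.numpy())  # and to b only where b was picked
```

Because tf.where routes the upstream gradient to whichever branch was selected, the same pattern works with any boolean mask, not only one built from absolute values.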
Upvotes: 1