Rodsnjr

TensorFlow: change tensor values given a condition

I am translating some NumPy code to TensorFlow.

It has the following line:

netout[..., 5:] *= netout[..., 5:] > obj_threshold

That syntax does not carry over to TensorFlow (tensors do not support assigning into a slice), and I'm having trouble finding functions with the same behavior.

First I tried:

netout[..., 5:] * netout[..., 5:] > obj_threshold

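But the result is a boolean-only tensor: since * binds more tightly than >, the expression is evaluated as (netout[..., 5:] * netout[..., 5:]) > obj_threshold. What I want is for every value from index 5 onward that is at or below obj_threshold to become 0.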
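
For reference, here is a minimal NumPy sketch of what the original line does, using a made-up 2 x 8 array (the values are illustrative only). It also reproduces the first attempt to show why it comes back boolean.

```python
import numpy as np

obj_threshold = 0.5
# Made-up 2 x 8 array; only entries from index 5 on the last axis are thresholded.
netout = np.array([
    [0.1, 0.2, 0.3, 0.4, 0.9, 0.2, 0.7, 0.5],
    [0.6, 0.1, 0.8, 0.3, 0.4, 0.9, 0.3, 0.6],
], dtype=np.float32)
original = netout.copy()

# The comparison yields a boolean array; *= multiplies it in place
# (True -> 1.0, False -> 0.0), so values not above the threshold become 0.
netout[..., 5:] *= netout[..., 5:] > obj_threshold
print(netout)

# The first attempt from the question: * binds tighter than >, so this is
# (x * x) > obj_threshold, a boolean array rather than the masked values.
attempt = original[..., 5:] * original[..., 5:] > obj_threshold
print(attempt.dtype)
```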

Answers (1)

javidcf

If you just wanted to set every value at or below obj_threshold to 0 (across the whole tensor), you could do:

netout = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))

Or:

netout = netout * tf.cast(netout > obj_threshold, netout.dtype)
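
A small runnable check of the two whole-tensor options above, assuming TensorFlow 2.x with eager execution (needed for .numpy()); the 2 x 3 input is made up for illustration.

```python
import numpy as np
import tensorflow as tf

obj_threshold = 0.5
netout = tf.constant([[0.1, 0.6, 0.5],
                      [0.9, 0.2, 0.7]], dtype=tf.float32)

# Option 1: keep netout where the condition is True, otherwise take 0.
by_where = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))

# Option 2: turn the condition into 1.0 / 0.0 and multiply.
by_cast = netout * tf.cast(netout > obj_threshold, netout.dtype)

# 0.5 is not strictly greater than the threshold, so it is zeroed as well.
print(by_where.numpy())
print(np.array_equal(by_where.numpy(), by_cast.numpy()))
```

For ordinary finite inputs the two agree. They differ on non-finite values: multiplying -inf or NaN by 0 gives NaN, whereas tf.where returns a clean 0 for those positions.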

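However, your case is a bit more complicated, because you only want the change to affect part of the tensor (positions 5 onward along the last axis). One thing you can do is build a boolean mask that is True for values above obj_threshold or for positions whose index along the last axis is below 5. The second condition is a 1-D tensor that broadcasts against the shape of netout.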

mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < 5)
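
To make the broadcasting concrete, here is a sketch on a made-up 2 x 8 tensor, assuming TensorFlow 2.x eager execution. The two conditions are built separately so their shapes are visible.

```python
import tensorflow as tf

obj_threshold = 0.5
# Made-up 2 x 8 input (values are illustrative only).
netout = tf.constant([
    [0.1, 0.2, 0.3, 0.4, 0.9, 0.2, 0.7, 0.5],
    [0.6, 0.1, 0.8, 0.3, 0.4, 0.9, 0.3, 0.6],
])

# Shape [8]: True for positions 0..4 along the last axis.
keep_prefix = tf.range(tf.shape(netout)[-1]) < 5
# Shape [2, 8]: True where the value itself is above the threshold.
above = netout > obj_threshold
# On boolean tensors | is a logical OR; it broadcasts [8] against [2, 8].
mask = above | keep_prefix

print(mask.shape)
print(mask.numpy().astype(int))
```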

Then you can use that mask with either of the previous methods:

netout = tf.where(mask, netout, tf.zeros_like(netout))

netout = netout * tf.cast(mask, netout.dtype)
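
A self-contained end-to-end sketch, assuming TensorFlow 2.x eager execution, that checks the mask approach against the original NumPy line on random data. threshold_tail and threshold_tail_concat are hypothetical helper names. The second helper swaps in a different technique (threshold only the tail slice, then rejoin it with tf.concat), shown only as an alternative to the mask.

```python
import numpy as np
import tensorflow as tf


def threshold_tail(netout, obj_threshold, start=5):
    """Hypothetical helper: the mask approach from this answer.

    Zeroes entries at index >= start on the last axis that are not above
    obj_threshold; entries before start are returned unchanged.
    """
    mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < start)
    return tf.where(mask, netout, tf.zeros_like(netout))


def threshold_tail_concat(netout, obj_threshold, start=5):
    """Hypothetical alternative: slice, threshold the tail, concatenate."""
    head = netout[..., :start]
    tail = netout[..., start:]
    tail = tail * tf.cast(tail > obj_threshold, tail.dtype)
    return tf.concat([head, tail], axis=-1)


rng = np.random.default_rng(0)
data = rng.random((2, 3, 8)).astype(np.float32)
obj_threshold = 0.5

# Reference result: the original NumPy line, applied to a copy.
expected = data.copy()
expected[..., 5:] *= expected[..., 5:] > obj_threshold

netout = tf.constant(data)
print(np.array_equal(threshold_tail(netout, obj_threshold).numpy(), expected))
print(np.array_equal(threshold_tail_concat(netout, obj_threshold).numpy(), expected))
```

The mask version keeps the full tensor shape throughout, which can be convenient inside a larger graph. The concat version avoids building a full-size mask but depends on the split point being a static slice.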
