I am translating some NumPy code to TensorFlow.
It contains the following line:
netout[..., 5:] *= netout[..., 5:] > obj_threshold
This line does not work as-is in TensorFlow (tensors do not support item assignment), and I'm having trouble finding functions with the same behavior.
First, I tried:
netout[..., 5:] * netout[..., 5:] > obj_threshold
But the result is a tensor of booleans only.
In this case, I want all values below obj_threshold to be set to 0.
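For reference, here is a minimal NumPy sketch of what the original line does. It assumes netout is a float array whose last axis holds 5 leading values followed by per-class scores; that layout, the helper name numpy_reference, and the sample numbers are hypothetical.

```python
import numpy as np

def numpy_reference(netout, obj_threshold):
    """Mirror of the NumPy line: zero entries in channels 5: that are not above obj_threshold."""
    out = netout.copy()  # copy so the caller's array is not modified in place
    # The boolean comparison is upcast to 0.0/1.0 when multiplied into the float slice.
    out[..., 5:] *= out[..., 5:] > obj_threshold
    return out

# Hypothetical sample: one row with 5 leading values and 3 class scores.
netout = np.array([[0.9, 0.1, 0.2, 0.3, 0.4, 0.2, 0.7, 0.5]], dtype=np.float32)
print(numpy_reference(netout, 0.5))
```

Here columns 5 and 7 are zeroed (0.2 and 0.5 are not above 0.5), column 6 keeps 0.7, and the first five columns are untouched even though some of them are below the threshold. A value exactly equal to obj_threshold is also zeroed, because the comparison is strict.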
If you just wanted to set all values at or below obj_threshold to 0, you could do:
netout = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))
Or:
netout = netout * tf.cast(netout > obj_threshold, netout.dtype)
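A small runnable sketch of the two whole-tensor variants above, assuming TensorFlow 2.x with eager execution (the function names and sample values are hypothetical):

```python
import tensorflow as tf

def zero_below_where(netout, obj_threshold):
    # Keep entries strictly above the threshold; select 0 for everything else.
    return tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))

def zero_below_cast(netout, obj_threshold):
    # Cast the boolean comparison to netout's dtype (True -> 1, False -> 0) and multiply.
    return netout * tf.cast(netout > obj_threshold, netout.dtype)

netout = tf.constant([[0.2, 0.7, 0.5, 0.9]], dtype=tf.float32)
print(zero_below_where(netout, 0.5).numpy())
print(zero_below_cast(netout, 0.5).numpy())
```

Both return [[0, 0.7, 0, 0.9]] here. One practical difference: the multiply-by-cast version turns NaN or -inf entries into NaN (NaN * 0 and -inf * 0 are both NaN), whereas tf.where simply selects 0 for them.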
However, your case is a bit more complicated, because you only want the change to affect part of the tensor (the last axis from index 5 on). So one thing you can do is build a boolean mask that is True for values above obj_threshold or for positions whose index along the last axis is below 5:
mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < 5)
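The mask works for any number of leading dimensions because tf.range(...) < 5 has shape [num_channels] and is broadcast against the full shape of netout by the | operator. A toy example, assuming TensorFlow 2.x (the shape and values are made up for illustration):

```python
import tensorflow as tf

# Hypothetical input: batch of 1, 2 boxes, 7 channels (5 leading values + 2 class scores).
netout = tf.constant([[[0.9, 0.1, 0.2, 0.3, 0.4, 0.2, 0.8],
                       [0.1, 0.6, 0.3, 0.2, 0.1, 0.7, 0.4]]], dtype=tf.float32)
obj_threshold = 0.5

# Shape [7]: True for channel indices 0..4, False from index 5 on.
keep_leading = tf.range(tf.shape(netout)[-1]) < 5
# Shape [1, 2, 7]: keep_leading is broadcast across the leading axes.
mask = (netout > obj_threshold) | keep_leading
print(mask.numpy())
```

Every position in the first five channels is True regardless of its value, so it is kept; in the remaining channels only values above the threshold are True.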
Then you can use that mask with either of the previous methods:
netout = tf.where(mask, netout, tf.zeros_like(netout))
netout = netout * tf.cast(mask, netout.dtype)
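Putting it together, here is a sketch of a reusable helper based on the mask approach, plus an alternative that is not from the answer: threshold only the netout[..., 5:] slice and re-join it with tf.concat. Both assume TensorFlow 2.x; the function names and the start parameter are hypothetical.

```python
import tensorflow as tf

def threshold_class_scores(netout, obj_threshold, start=5):
    """Mask approach: zero channels >= start whose values are not above obj_threshold."""
    # `start` is a hypothetical parameter; the question hard-codes 5.
    in_head = tf.range(tf.shape(netout)[-1]) < start
    mask = (netout > obj_threshold) | in_head
    return tf.where(mask, netout, tf.zeros_like(netout))

def threshold_class_scores_concat(netout, obj_threshold, start=5):
    """Alternative (not from the answer): threshold only the tail slice, then re-join."""
    head = netout[..., :start]
    tail = netout[..., start:]
    tail = tail * tf.cast(tail > obj_threshold, tail.dtype)
    return tf.concat([head, tail], axis=-1)

netout = tf.random.uniform([2, 4, 10], seed=0)
a = threshold_class_scores(netout, 0.3)
b = threshold_class_scores_concat(netout, 0.3)
print(bool(tf.reduce_all(tf.equal(a, b))))  # prints True: both approaches agree
```

The mask version does the whole job in a single tf.where over the tensor; the concat version mirrors the NumPy slice more literally. Either should give the same result for finite inputs.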