I want to filter a tensor by keeping only the largest 10% of its entries. Is there a TensorFlow function to do that? What would a possible implementation look like? I am looking for something that can handle tensors of shape [N,W,H,C] and [N,W*H*C].
By "filter" I mean that the tensor keeps its shape, but every entry except the largest 10% is set to zero.
Is that possible?
The correct way to do this is to compute the 90th percentile, for example with tf.contrib.distributions.percentile:
import tensorflow as tf
images = ... # [N, W, H, C]
n = tf.shape(images)[0]
images_flat = tf.reshape(images, [n, -1])
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, 1]),
                        images, tf.zeros_like(images))
If you want to be ready for TensorFlow 2.x, where tf.contrib will be removed, you can instead use TensorFlow Probability, which is where the percentile function will live permanently (as tfp.stats.percentile).
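If you would rather avoid the extra TensorFlow Probability dependency, the same per-sample threshold can be obtained with tf.math.top_k instead of a percentile. This is a swapped-in technique, not the percentile approach above, and a minimal sketch assuming TensorFlow 2.x with eager execution; keep_top_fraction is a hypothetical helper name. Because every non-batch dimension is flattened first, it accepts both [N, W, H, C] and [N, W*H*C] inputs.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def keep_top_fraction(images, fraction=0.1):
    """Hypothetical helper: keep the largest `fraction` of entries of each
    sample and zero the rest. Accepts [N, W, H, C] or [N, W*H*C]."""
    n = tf.shape(images)[0]
    flat = tf.reshape(images, [n, -1])  # [N, D]
    d = tf.shape(flat)[1]
    # Entries to keep per sample: fraction * D, rounded, at least one.
    k = tf.maximum(tf.cast(tf.round(fraction * tf.cast(d, tf.float32)), tf.int32), 1)
    top_values, _ = tf.math.top_k(flat, k=k)  # [N, k], sorted descending
    threshold = top_values[:, -1]             # k-th largest value per sample
    kept = tf.where(flat >= threshold[:, None], flat, tf.zeros_like(flat))
    return tf.reshape(kept, tf.shape(images))  # restore the original shape

images = tf.reshape(tf.range(40, dtype=tf.float32), [2, 2, 5, 2])
print(keep_top_fraction(images, 0.1).numpy().reshape(2, -1))
```

As with the percentile version, entries tied with the threshold are all kept, so slightly more than 10% can survive when values repeat.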
EDIT: If you want to do the filtering per channel, you can modify the code slightly like this:
import tensorflow as tf
images = ... # [N, W, H, C]
shape = tf.shape(images)
n, c = shape[0], shape[3]
images_flat = tf.reshape(images, [n, -1, c])
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, c]),
                        images, tf.zeros_like(images))
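A per-channel variant without tf.contrib could look like the following. It is a sketch assuming TensorFlow 2.x and again uses tf.math.top_k in place of the percentile (a substitution, not the answer's method); keep_top_fraction_per_channel is a hypothetical name. The input is transposed to [N, C, W*H] so that top_k, which works along the last axis, finds one threshold per sample and channel.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def keep_top_fraction_per_channel(images, fraction=0.1):
    """Hypothetical helper: for each sample and channel, keep the largest
    `fraction` of the W*H entries and zero the rest. Expects [N, W, H, C]."""
    shape = tf.shape(images)
    n, c = shape[0], shape[3]
    # [N, W, H, C] -> [N, C, W*H] so top_k works along the last axis.
    per_channel = tf.reshape(tf.transpose(images, [0, 3, 1, 2]), [n, c, -1])
    d = tf.shape(per_channel)[2]
    # Entries to keep per channel: fraction * W*H, rounded, at least one.
    k = tf.maximum(tf.cast(tf.round(fraction * tf.cast(d, tf.float32)), tf.int32), 1)
    top_values, _ = tf.math.top_k(per_channel, k=k)  # [N, C, k], descending
    threshold = top_values[..., -1]                  # k-th largest per (sample, channel)
    # Broadcast the [N, C] thresholds back over W and H.
    threshold = tf.reshape(threshold, [n, 1, 1, c])
    return tf.where(images >= threshold, images, tf.zeros_like(images))

# Channel 0 holds 0..19, channel 1 holds 100 down to 81.
x = np.stack([np.arange(20).reshape(4, 5),
              100 - np.arange(20).reshape(4, 5)], axis=-1)[None].astype(np.float32)
out = keep_top_fraction_per_channel(tf.constant(x), 0.1).numpy()
print(out[0, ..., 0])
print(out[0, ..., 1])
```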
I haven't found a built-in method yet. Try this workaround, which uses a single threshold for the whole tensor:
import numpy as np
import tensorflow as tf

def filter_top_ratio(tensor, ratio):
    # Total number of entries; works for any shape, e.g. [N,W,H,C] or [N,W*H*C]
    num_entries = tf.size(tensor)
    # Number of entries to keep (at least one, otherwise x[-0] would pick the minimum)
    num_to_keep = tf.maximum(
        tf.cast(ratio * tf.cast(num_entries, tf.float32), tf.int32), 1)
    # Calculate threshold: the num_to_keep-th largest value
    x = tf.contrib.framework.sort(tf.reshape(tensor, [-1]))
    threshold = x[-num_to_keep]
    # Filter the tensor: zero out every entry below the threshold
    mask = tf.cast(tf.greater_equal(tensor, threshold), tensor.dtype)
    return tf.multiply(tensor, mask)

tensor = tf.constant(np.arange(40).reshape(2, 4, 5), dtype=tf.float32)
filtered_tensor = filter_top_ratio(tensor, 0.1)

# Print result
sess = tf.InteractiveSession()
print(tensor.eval())
print(filtered_tensor.eval())
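Under TensorFlow 2.x, tf.contrib.framework.sort and tf.InteractiveSession are gone, but the same whole-tensor workaround ports directly to tf.sort in eager mode. This is a sketch assuming TensorFlow 2.x; filter_top_ratio_tf2 is a hypothetical name. Sorting in descending order makes the threshold simply element num_to_keep - 1.

```python
import numpy as np
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def filter_top_ratio_tf2(tensor, ratio):
    """Hypothetical helper: keep the largest `ratio` of entries of the
    whole tensor (one global threshold) and zero the rest."""
    num_entries = tf.size(tensor)
    # Number of entries to keep, at least one.
    num_to_keep = tf.maximum(
        tf.cast(ratio * tf.cast(num_entries, tf.float32), tf.int32), 1)
    # Sort descending so the threshold is the num_to_keep-th element.
    x = tf.sort(tf.reshape(tensor, [-1]), direction='DESCENDING')
    threshold = x[num_to_keep - 1]
    return tf.where(tensor >= threshold, tensor, tf.zeros_like(tensor))

tensor = tf.constant(np.arange(40).reshape(2, 4, 5), dtype=tf.float32)
print(filter_top_ratio_tf2(tensor, 0.1).numpy())
```

Note that this keeps the top 10% of the entire batch, not of each sample; for per-sample filtering the threshold has to be computed per row, as in the percentile answer.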