Gilfoyle

Reputation: 3616

TensorFlow: Keep the largest 10% of a tensor's entries

I want to filter a tensor so that only its largest 10% of entries are kept. Is there a TensorFlow function that does this? What would a possible implementation look like? I am looking for something that can handle tensors of shape [N,W,H,C] as well as [N,W*H*C].

By "filter" I mean that the tensor keeps its shape, but only the largest 10% of entries keep their values. Every other entry is set to zero.

Is that possible?
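To pin down the requested behaviour, here is a minimal NumPy reference sketch. It is not TensorFlow and not taken from either answer below. It assumes that "10%" means the ceiling of 10% of each sample's entries (at least one), that every sample is ranked separately, and that all entries tied with the threshold are kept. The name `keep_top_percent_np` is hypothetical.

```python
import numpy as np


def keep_top_percent_np(x, percent=10):
    """Zero all but the largest `percent`% of entries of each sample.

    `x` has shape [N, ...] (e.g. [N, W, H, C] or [N, W*H*C]); the output
    has the same shape. Entries tied with the threshold are all kept.
    """
    n = x.shape[0]
    flat = x.reshape(n, -1)                       # [N, M], one row per sample
    m = flat.shape[1]
    k = max(1, (m * percent + 99) // 100)         # ceil(M * percent / 100), exact integer math
    # The k-th largest value of each row is that sample's threshold.
    threshold = np.sort(flat, axis=1)[:, m - k]   # [N]
    kept = np.where(flat >= threshold[:, None], flat, 0)
    return kept.reshape(x.shape)


x = np.arange(40, dtype=np.float32).reshape(2, 4, 5, 1)  # N=2, 20 entries per sample
out = keep_top_percent_np(x)
print(sorted(out[out != 0].tolist()))  # prints [18.0, 19.0, 38.0, 39.0]
```

Because each sample is flattened before ranking, the same function accepts the [N, W*H*C] layout unchanged, and the output can be used to check a TensorFlow implementation.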

Upvotes: 3

Views: 716

Answers (2)

javidcf

Reputation: 59701

One way to do this is to compute the 90th percentile of each sample and zero out everything below it, for example with tf.contrib.distributions.percentile:

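import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

images = ... # [N, W, H, C]
n = tf.shape(images)[0]
# Flatten each sample into one row: [N, W*H*C]
images_flat = tf.reshape(images, [n, -1])
# One 90th-percentile threshold per sample: [N]
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
# Keep entries at or above their sample's threshold, zero the rest
# (for [N, W*H*C] input, reshape p to [n, 1] instead)
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, 1]),
                        images, tf.zeros_like(images))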

If you want to be ready for TensorFlow 2.x, where tf.contrib has been removed, you can use TensorFlow Probability instead, where the same function lives on as tfp.stats.percentile.
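If you would rather not depend on TensorFlow Probability, a minimal TensorFlow 2.x sketch can swap the percentile for `tf.math.top_k`. This is a substitute technique, not the percentile call above: `top_k` ranks along the last axis, so the smallest of the k values it returns for a flattened sample is that sample's threshold. It assumes eager execution and an integer `percent`, and it keeps entries tied with the threshold, so slightly more than k entries can survive. `keep_top_percent` is a hypothetical name.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def keep_top_percent(images, percent=10):
    """Zero all but the largest `percent`% of entries of each sample.

    Accepts [N, W, H, C] or [N, W*H*C]; the output has the input's shape.
    """
    shape = tf.shape(images)
    flat = tf.reshape(images, [shape[0], -1])        # [N, M]
    m = tf.shape(flat)[1]
    # k = ceil(M * percent / 100), at least 1, in exact integer arithmetic.
    k = tf.maximum((m * percent + 99) // 100, 1)
    # top_k returns the k largest values of each row, sorted descending;
    # the last one is the k-th largest, i.e. the per-sample threshold.
    top_values = tf.math.top_k(flat, k=k).values     # [N, k]
    threshold = top_values[:, -1:]                   # [N, 1]
    kept = tf.where(flat >= threshold, flat, tf.zeros_like(flat))
    return tf.reshape(kept, shape)                   # restore the input shape
```

For a [2, 4, 5, 1] tensor holding 0 to 39, each sample has 20 entries, so k = 2 and the result keeps 18 and 19 in the first sample and 38 and 39 in the second. Unlike a full sort, `top_k` only has to find the k largest values per row.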

EDIT: To filter per channel instead (one threshold for each sample and channel), modify the code slightly:

import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib)

images = ... # [N, W, H, C]
shape = tf.shape(images)
n, c = shape[0], shape[3]
# Group entries by channel: [N, W*H, C]
images_flat = tf.reshape(images, [n, -1, c])
# Percentile over the spatial axis: one threshold per (sample, channel), [N, C]
p = tf.contrib.distributions.percentile(images_flat, 90, axis=1, interpolation='higher')
images_top10 = tf.where(images >= tf.reshape(p, [n, 1, 1, c]),
                        images, tf.zeros_like(images))
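Under TensorFlow 2.x, a per-channel version can also avoid the percentile call by using `tf.math.top_k` (a swapped-in technique, not the answer's method). Because `top_k` only ranks along the last axis, the channels are moved in front of the spatial positions first. This is a minimal sketch assuming [N, W, H, C] input, eager execution and an integer `percent`; `keep_top_percent_per_channel` is a hypothetical name.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)


def keep_top_percent_per_channel(images, percent=10):
    """Zero all but the largest `percent`% of entries of each (sample, channel)."""
    shape = tf.shape(images)
    n, c = shape[0], shape[3]
    # [N, W, H, C] -> [N, W*H, C] -> [N, C, W*H] so top_k ranks spatial positions.
    per_channel = tf.transpose(tf.reshape(images, [n, -1, c]), [0, 2, 1])
    m = tf.shape(per_channel)[2]
    k = tf.maximum((m * percent + 99) // 100, 1)   # ceil(M * percent / 100), at least 1
    # k-th largest value per (sample, channel): [N, C]
    threshold = tf.math.top_k(per_channel, k=k).values[:, :, -1]
    threshold = tf.reshape(threshold, [n, 1, 1, c])  # broadcast over W and H
    return tf.where(images >= threshold, images, tf.zeros_like(images))
```

If one channel holds much larger values than another, a per-sample threshold would keep only entries from the large channel; here each channel keeps its own top entries.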

Upvotes: 3

William

Reputation: 181

I haven't found a built-in method. Try this workaround (TensorFlow 1.x):

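import numpy as np
import tensorflow as tf  # TensorFlow 1.x (uses tf.contrib and InteractiveSession)

def keep_top_ratio(tensor, ratio):  # renamed so it doesn't shadow Python's filter()
  # Number of entries to keep, counted over the WHOLE tensor (batch included)
  num_entries = tf.size(tensor)
  num_to_keep = tf.cast(tf.multiply(ratio, tf.cast(num_entries, tf.float32)), tf.int32)
  num_to_keep = tf.maximum(num_to_keep, 1)  # x[-0] would be x[0] and keep everything
  # Calculate threshold: the num_to_keep-th largest value
  x = tf.contrib.framework.sort(tf.reshape(tensor, [num_entries]))
  threshold = x[-num_to_keep]
  # Filter the tensor: zero out entries below the threshold
  mask = tf.cast(tf.greater_equal(tensor, threshold), tensor.dtype)
  return tf.multiply(tensor, mask)

tensor = tf.constant(np.arange(40).reshape(2, 4, 5), dtype=tf.float32)
filtered_tensor = keep_top_ratio(tensor, 0.1)
# Print result
sess = tf.InteractiveSession()
print(tensor.eval())
print(filtered_tensor.eval())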
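Note that this workaround ranks all entries of the tensor together, batch dimension included, so with several samples the kept 10% can all come from one of them. A NumPy mirror of the same logic makes this visible on the example input. It is a sketch for checking, not part of the original answer, and `keep_top_ratio_np` is a hypothetical name.

```python
import numpy as np


def keep_top_ratio_np(x, ratio):
    """Mirror of the workaround: one threshold for the whole tensor."""
    num_to_keep = max(int(ratio * x.size), 1)     # truncates like the cast to int32
    threshold = np.sort(x.ravel())[-num_to_keep]  # k-th largest over ALL entries
    return x * (x >= threshold)


x = np.arange(40, dtype=np.float32).reshape(2, 4, 5)
out = keep_top_ratio_np(x, 0.1)
print(np.count_nonzero(out[0]), np.count_nonzero(out[1]))  # prints: 0 4
```

The four kept values, 36 to 39, all lie in the second sample and the first sample is zeroed entirely. For per-sample filtering, flatten to [N, -1] and compute one threshold per row instead.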

Upvotes: 0
