J.Doe

Reputation: 3649

How to get the count of an element in a tensor in TensorFlow?

I want to get the count of a particular element in a tensor. For example, t = [1, 2, 0, 0, 0, 0] (t is a tensor) contains 4 zeros. If t were a Python list, I could get that number with t.count(0), but I can't find an equivalent function in TensorFlow. How can I count the zeros?

Upvotes: 15

Views: 30247

Answers (3)

Daniel Slater

Reputation: 4143

There isn't a built-in count method in TensorFlow right now, but you can build one from existing ops like so:

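import tensorflow as tf

def tf_count(t, val):
    # Boolean mask: True wherever an element equals val
    elements_equal_to_value = tf.equal(t, val)
    # Cast True/False to 1/0 so the mask can be summed
    as_ints = tf.cast(elements_equal_to_value, tf.int32)
    count = tf.reduce_sum(as_ints)
    return count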
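In TensorFlow 2.x, which executes eagerly, the same mask-and-sum idea runs without a session. The sketch below assumes TF 2.x and applies it to the tensor from the question; it also shows `tf.math.count_nonzero` on the boolean mask as an equivalent one-liner (it returns an int64 scalar rather than int32).

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

t = tf.constant([1, 2, 0, 0, 0, 0])
mask = tf.equal(t, 0)  # True where the element is 0

# Mask-and-sum, as in tf_count above: int32 scalar
count_sum = tf.reduce_sum(tf.cast(mask, tf.int32))

# count_nonzero treats True as nonzero: int64 scalar
count_nz = tf.math.count_nonzero(mask)

print(int(count_sum), int(count_nz))  # 4 4
```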

Upvotes: 12

Salvador Dali

Reputation: 222889

To count just a specific element, you can create a boolean mask, cast it to int and sum it:

import tensorflow as tf  # TF 1.x graph-mode API (tf.Session)

X = tf.constant([6, 3, 3, 3, 0, 1, 3, 6, 7])
# Mask of elements equal to 3, cast to int32 and summed
res = tf.reduce_sum(tf.cast(tf.equal(X, 3), tf.int32))
with tf.Session() as sess:
    print(sess.run(res))
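The mask-and-sum pattern also works along one axis of a multi-dimensional tensor by passing `axis` to `tf.reduce_sum`. A minimal sketch assuming TF 2.x eager execution; the 2-D matrix is a made-up example, and the snippet counts the 3s in each row and then overall.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

# Hypothetical 2-D input
X = tf.constant([[6, 3, 3],
                 [3, 0, 1],
                 [3, 6, 7]])

# Sum the int mask along axis 1 to get one count per row
per_row = tf.reduce_sum(tf.cast(tf.equal(X, 3), tf.int32), axis=1)
total = tf.reduce_sum(per_row)

print(per_row.numpy(), int(total))  # [2 1 1] 4
```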

You can also count every distinct element in the list/tensor using tf.unique_with_counts:

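import tensorflow as tf  # TF 1.x graph-mode API (tf.Session)

X = tf.constant([6, 3, 3, 3, 0, 1, 3, 6, 7])
# y: distinct values, idx: position in y of each element of X, cnts: count of each y
y, idx, cnts = tf.unique_with_counts(X)
with tf.Session() as sess:
    a, _, b = sess.run([y, idx, cnts])
    print(a)
    print(b)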
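For reference, here is what `tf.unique_with_counts` yields for this input, as a sketch assuming TF 2.x eager execution: `y` holds the distinct values in order of first appearance and the i-th count belongs to `y[i]`. For tensors of non-negative integers, `tf.math.bincount` is another option; it returns a count for every value from 0 up to the maximum, including values that never occur.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

X = tf.constant([6, 3, 3, 3, 0, 1, 3, 6, 7])
y, idx, counts = tf.unique_with_counts(X)

# Pair each distinct value with its count
value_counts = dict(zip(y.numpy().tolist(), counts.numpy().tolist()))
print(value_counts)  # {6: 2, 3: 4, 0: 1, 1: 1, 7: 1}

# Non-negative integers only: entry v is the number of occurrences of v
by_value = tf.math.bincount(X)
print(by_value.numpy())  # [1 1 0 4 0 0 2 1]
```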

Upvotes: 7

Eli Bixby

Reputation: 1178

As an addition to Slater's answer above: if you want the counts of all the elements, you can use one_hot and reduce_sum to avoid any looping in Python. For example, the snippet below returns the vocab_size most frequent entries of word_tensor, ordered by number of occurrences (vocab_size must not exceed the number of unique words, or top_k will fail).

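import tensorflow as tf

def build_vocab(word_tensor, vocab_size):
  # unique: distinct words; idx: position in unique of each word in word_tensor
  unique, idx = tf.unique(word_tensor)
  # One row per word in word_tensor, one column per unique word
  counts_one_hot = tf.one_hot(
      idx,
      tf.shape(unique)[0],
      dtype=tf.int32
  )
  # Summing over the rows gives the number of occurrences of each unique word
  counts = tf.reduce_sum(counts_one_hot, 0)
  # Indices of the vocab_size largest counts, in descending order
  _, indices = tf.nn.top_k(counts, k=vocab_size)
  return tf.gather(unique, indices)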
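To see what the one_hot and reduce_sum steps compute, and why the intermediate tensor grows with len(word_tensor) times the number of unique words, here is a small sketch assuming TF 2.x eager execution and a made-up word list.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

words = tf.constant(["a", "b", "a", "c", "a", "b"])  # made-up example input

unique, idx = tf.unique(words)  # unique = [a, b, c], idx = [0, 1, 0, 2, 0, 1]

# One row per input word, one column per unique word: shape [6, 3]
one_hot = tf.one_hot(idx, tf.shape(unique)[0], dtype=tf.int32)

# Column sums are the per-word counts: [3, 2, 1]
counts = tf.reduce_sum(one_hot, 0)
print(one_hot.shape, counts.numpy())
```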

EDIT: After a little experimentation, I found that the one_hot tensor, which has shape [len(word_tensor), number of unique words], can easily grow beyond TF's maximum tensor size. It's likely more memory-efficient (if a little less elegant) to replace the counts computation with something like this:

# Add one one-hot row per word instead of materializing all rows at once
counts = tf.foldl(
  lambda counts, item: counts + tf.one_hot(
      item, tf.shape(unique)[0], dtype=tf.int32),
  idx,
  initializer=tf.zeros_like(unique, dtype=tf.int32),
  back_prop=False
)
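A different route, not the author's, is to let `tf.unique_with_counts` (used in the previous answer) produce the counts directly, so neither the full one_hot tensor nor a sequential `tf.foldl` is needed. A minimal sketch assuming TF 2.x; `build_vocab_no_one_hot` and the word list are hypothetical names chosen for illustration.

```python
import tensorflow as tf  # assumes TensorFlow 2.x (eager execution)

def build_vocab_no_one_hot(word_tensor, vocab_size):
    """Hypothetical variant of build_vocab that never builds a one_hot tensor."""
    # Distinct words plus how often each occurs (first-occurrence order)
    unique, _, counts = tf.unique_with_counts(word_tensor)
    # Equivalent: unique, idx = tf.unique(word_tensor); counts = tf.math.bincount(idx)
    # Indices of the vocab_size largest counts, sorted descending
    _, indices = tf.math.top_k(counts, k=vocab_size)
    return tf.gather(unique, indices)

words = tf.constant(["a", "b", "a", "c", "a", "b"])  # made-up example input
vocab = build_vocab_no_one_hot(words, 2)
print(vocab.numpy())  # [b'a' b'b']
```

The bincount variant in the comment works because every index from 0 to len(unique) - 1 appears in idx at least once, so it returns exactly one count per unique word.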

Upvotes: 2
