tinkerbell

Reputation: 117

Gather values from a 2-D tensor in TensorFlow

Hi, TensorFlow beginner here. I'm trying to get the values of certain elements in a 2-D tensor; in my case, class scores from a probability matrix.

The probability matrix probs has shape (1000, 81): a batch size of 1000 and 81 classes. class_ids has shape (1000,) and contains the index of the highest class score for each sample. How do I get the corresponding class scores from the probability matrix using tf.gather?

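class_ids = tf.cast(tf.argmax(probs, axis=1), tf.int32)
# Fails: gather_nd reads the last axis of class_ids (length 1000) as ONE index into a rank-2 tensor
class_scores = tf.gather_nd(probs, class_ids)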

class_scores should be a tensor of shape (1000,) containing the highest class_score for each sample.

Right now I'm using a workaround that looks like this:

# Works, but slices one row at a time in a Python loop
class_score_count = []
for i in range(probs.shape[0]):
    prob = probs[i,:]
    class_score = prob[class_ids[i]]
    class_score_count.append(class_score)
class_scores = tf.stack(class_score_count, axis=0)

Thanks for the help!

Upvotes: 3

Views: 592

Answers (3)

Vladimir Panteleev

Reputation: 25187

I think this is what the batch_dims argument for tf.gather is for.
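A minimal sketch of the batch_dims idea, assuming TensorFlow 2.x with eager execution (tf.gather has accepted batch_dims since roughly TF 1.14/2.0). The small 3x4 probs matrix is hypothetical example data standing in for the (1000, 81) one. With batch_dims=1, row i of probs is paired with class_ids[i], so the result has shape (batch,).

```python
import tensorflow as tf

# Hypothetical example data: 3 samples, 4 classes (stands in for the (1000, 81) matrix)
probs = tf.constant([[0.1, 0.6, 0.2, 0.1],
                     [0.7, 0.1, 0.1, 0.1],
                     [0.2, 0.2, 0.5, 0.1]])
class_ids = tf.argmax(probs, axis=1, output_type=tf.int32)  # shape (3,)

# batch_dims=1: row i of probs is indexed by class_ids[i]; axis defaults to batch_dims
class_scores = tf.gather(probs, class_ids, batch_dims=1)  # shape (3,)
print(class_scores.numpy())
```

The output shape follows from params.shape[:1] + indices.shape[1:], i.e. one score per sample.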

Upvotes: 0

javidcf

Reputation: 59731

You can do it with tf.gather_nd like this:

class_ids = tf.cast(tf.argmax(probs, axis=1), tf.int32)
# If the batch size is static, you can use probs.shape[0].value (TF 1.x) or probs.shape[0] (TF 2.x) instead of tf.shape(probs)[0]
row_ids = tf.range(tf.shape(probs)[0], dtype=tf.int32)
idx = tf.stack([row_ids, class_ids], axis=1)
class_scores = tf.gather_nd(probs, idx)

You could also just use tf.reduce_max. It computes the maximum a second time, but unless your data is large it may not be noticeably slower:

class_scores = tf.reduce_max(probs, axis=1)
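A small self-contained check of the two approaches above, assuming TensorFlow 2.x eager mode and a hypothetical 3x4 probs matrix. It builds the (row, class) index pairs for tf.gather_nd and confirms the gathered scores equal the row maxima from tf.reduce_max.

```python
import tensorflow as tf

# Hypothetical example data: 3 samples, 4 classes
probs = tf.constant([[0.1, 0.6, 0.2, 0.1],
                     [0.7, 0.1, 0.1, 0.1],
                     [0.2, 0.2, 0.5, 0.1]])
class_ids = tf.cast(tf.argmax(probs, axis=1), tf.int32)  # [1, 0, 2]
row_ids = tf.range(tf.shape(probs)[0], dtype=tf.int32)   # [0, 1, 2]
idx = tf.stack([row_ids, class_ids], axis=1)              # [[0, 1], [1, 0], [2, 2]]
class_scores = tf.gather_nd(probs, idx)                   # shape (3,)

# Raises an error if the gathered scores differ from the row maxima
tf.debugging.assert_equal(class_scores, tf.reduce_max(probs, axis=1))
print(class_scores.numpy())
```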

Upvotes: 2

Jai
Jai

Reputation: 3300

  • You need to run the class_ids tensor to get its values (this applies to TF 1.x graph mode).
  • The values are returned as a NumPy array.
  • You can then access the NumPy array normally, e.g. with a loop.
  • You have to do something like this (x and X_data being your own input placeholder and data): predictions = sess.run(tf.argmax(probs, 1), feed_dict={x: X_data})
  • The predictions variable then holds all the information you need.
  • TensorFlow only returns the values of tensors you explicitly run.
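Once the values are NumPy arrays (for example, as returned by sess.run), the per-row lookup does not need a Python loop. A minimal sketch using NumPy integer-array indexing; probs_np is a hypothetical stand-in for the fetched probability array, not part of the original answer.

```python
import numpy as np

# Hypothetical stand-in for the array returned by sess.run(probs, feed_dict=...)
probs_np = np.array([[0.1, 0.6, 0.2, 0.1],
                     [0.7, 0.1, 0.1, 0.1],
                     [0.2, 0.2, 0.5, 0.1]], dtype=np.float32)
class_ids_np = probs_np.argmax(axis=1)  # index of the highest score per row

# Integer-array indexing: take column class_ids_np[i] from row i, for every i at once
class_scores_np = probs_np[np.arange(probs_np.shape[0]), class_ids_np]
print(class_scores_np)
```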

Upvotes: 0
