I want to evaluate a TensorFlow classification model.
To compute the accuracy, I have the following code:
predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions)
This works well for single-label classification, but now I want to do multilabel classification, where each label is an array of integers instead of a single integer.
Here is an example of a label stored in label_ids: [0, 1, 1, 0, 1, 0], and an example of the predictions from the tensor logits: [0.1, 0.8, 0.9, 0.1, 0.6, 0.2].
What function should I use instead of argmax to do this? (My labels are arrays of 6 integers, each with a value of either 0 or 1.)
If needed, we can assume a threshold of 0.5.
It is probably better to do this type of post-processing evaluation outside of tensorflow, where it is more natural to try several different thresholds.
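As a rough sketch of that idea, assuming the model outputs have already been passed through a sigmoid and pulled out as NumPy arrays, a threshold sweep could look like this. The helper name `multilabel_accuracies` is hypothetical. It reports two numbers, because "accuracy" can mean either in the multilabel setting: the fraction of individual labels predicted correctly, and the fraction of samples whose whole label vector is correct (exact match).

```python
import numpy as np

def multilabel_accuracies(probs, labels, threshold):
    """Hypothetical helper: (per-label accuracy, exact-match accuracy) at a threshold.

    probs:  float array of shape (batch, num_labels), assumed to be sigmoid outputs
    labels: 0/1 integer array of the same shape
    """
    preds = (probs > threshold).astype(labels.dtype)
    matches = preds == labels
    per_label = matches.mean()            # fraction of individual labels that are correct
    exact = matches.all(axis=1).mean()    # fraction of samples with every label correct
    return float(per_label), float(exact)

# Example values taken from the question (a batch of one sample).
labels = np.array([[0, 1, 1, 0, 1, 0]])
probs = np.array([[0.1, 0.8, 0.9, 0.1, 0.6, 0.2]])

for t in (0.3, 0.5, 0.7):
    per_label, exact = multilabel_accuracies(probs, labels, t)
    print(f"threshold={t}: per-label={per_label:.3f}, exact-match={exact:.3f}")
```

Raising the threshold to 0.7 drops the 0.6 entry to 0, so one label of six is wrong: per-label accuracy falls to 5/6 while exact-match accuracy falls to 0.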
If you want to do it in TensorFlow, you can consider:
predictions = tf.math.greater(logits, tf.constant(0.5))
This returns a boolean tensor with the same shape as logits, with True for every entry greater than 0.5. You can then calculate accuracy as before; note that this averages over every individual label rather than over whole samples. This is suitable for cases where several labels can be true at the same time for a given sample.
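A minimal TensorFlow 2.x sketch of this approach, using the question's example values, is below. It assumes eager execution and that the tensor already holds probabilities (the question's values all lie in [0, 1]); if it holds raw logits, apply `tf.sigmoid` first, or equivalently compare the logits against 0 instead of 0.5. It also computes an exact-match accuracy alongside the per-label one.

```python
import tensorflow as tf

# Assumes TensorFlow 2.x (eager execution). Values are the question's example;
# `probs` is assumed to hold sigmoid probabilities, not raw logits.
label_ids = tf.constant([[0, 1, 1, 0, 1, 0]], dtype=tf.int32)
probs = tf.constant([[0.1, 0.8, 0.9, 0.1, 0.6, 0.2]], dtype=tf.float32)

# Boolean tensor, same shape as probs: True where the probability exceeds 0.5.
predictions = tf.math.greater(probs, tf.constant(0.5))

# Element-wise comparison with the 0/1 labels.
correct = tf.equal(tf.cast(predictions, tf.int32), label_ids)

# Per-label accuracy: average over every individual label.
per_label_accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

# Exact-match accuracy: a sample counts only if all of its labels are correct.
exact_match_accuracy = tf.reduce_mean(
    tf.cast(tf.reduce_all(correct, axis=1), tf.float32))

print(per_label_accuracy.numpy(), exact_match_accuracy.numpy())
```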
Use the code below to calculate accuracy in multiclass classification (one label per sample).
tf.argmax returns the index of the maximum value along axis 1 for both y_pred and y_true (the actual labels).
tf.equal then compares the two index tensors element-wise, returning True where they match and False otherwise.
Casting the booleans to float (0.0 or 1.0) and applying tf.reduce_mean gives the accuracy.
correct_mask = tf.equal(tf.argmax(y_pred,1), tf.argmax(y_true,1))
accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))
Example with data:
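import numpy as np
import tensorflow as tf  # TensorFlow 1.x API: tf.Session was removed in 2.x

y_pred = np.array([[0.1,0.5,0.4], [0.2,0.6,0.2], [0.9,0.05,0.05]])
y_true = np.array([[0,1,0],[0,0,1],[1,0,0]])

# Compare the index of the highest score per row in prediction and truth
correct_mask = tf.equal(tf.argmax(y_pred,1), tf.argmax(y_true,1))
# Fraction of rows where the predicted class matches the true class
accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))

with tf.Session() as sess:
    # print(sess.run([correct_mask]))
    print(sess.run([accuracy]))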
Output:
[0.6666667]
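Since `tf.Session` no longer exists in TensorFlow 2.x, here is a sketch of the same computation under eager execution, assuming TF 2.x. It also cross-checks the result against the built-in `tf.keras.metrics.CategoricalAccuracy`, which likewise compares the argmax of the true and predicted rows.

```python
import numpy as np
import tensorflow as tf

# Same data as the example above.
y_pred = np.array([[0.1, 0.5, 0.4], [0.2, 0.6, 0.2], [0.9, 0.05, 0.05]])
y_true = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]])

# In TF 2.x, ops run eagerly, so no Session is needed.
correct_mask = tf.equal(tf.argmax(y_pred, 1), tf.argmax(y_true, 1))
accuracy = tf.reduce_mean(tf.cast(correct_mask, tf.float32))
print(accuracy.numpy())

# The built-in Keras metric computes the same quantity.
metric = tf.keras.metrics.CategoricalAccuracy()
metric.update_state(y_true, y_pred)
print(metric.result().numpy())
```

The second row is the only miss (predicted class 1, true class 2), so both values should be 2/3.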