rvdinter

Reputation: 51

Custom Keras Metrics Class -> Metric at a certain recall value

I am trying to build a metric that is comparable to the metrics.PrecisionAtRecall class. To do so, I've tried to build a custom metric by extending the keras.metrics.Metric class.

The original function is WSS = (TN + FN)/N - 1 + TP/(TP + FN), and it should be calculated at a certain recall value, say 95%.
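As a quick sanity check on the formula, here is a plain-Python sketch that evaluates it from raw confusion-matrix counts. The function name and the counts are hypothetical, chosen only for illustration:

```python
def work_saved_over_sampling(tp, tn, fp, fn):
    """WSS = (TN + FN)/N - 1 + TP/(TP + FN), computed from raw confusion-matrix counts."""
    n = tp + tn + fp + fn
    recall = tp / (tp + fn) if tp + fn else 0.0
    return (tn + fn) / n - 1 + recall


# Hypothetical screening run: 1000 studies, 100 of them relevant.
# At some threshold 95 relevant studies are caught (recall 0.95) and
# 200 irrelevant ones are still flagged for manual review.
print(work_saved_over_sampling(tp=95, tn=700, fp=200, fn=5))  # ≈ 0.655
```

Flagging every study (TN = FN = 0, recall 1) gives WSS = 0, i.e. no work saved, which is the baseline the metric is measured against.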

What I have so far is the following:

import tensorflow as tf
from tensorflow.keras import backend


class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
    def __init__(self, recall, name='wss_at_recall', **kwargs):
        super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
        self.recall = recall  # target recall (not used yet: this is the open question)
        self.wss = self.add_weight(name='wss', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Binarize predictions and labels at a fixed 0.5 threshold
        y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), tf.float32)
        y_pred_neg = 1 - y_pred_pos
        y_pos = tf.cast(backend.round(backend.clip(y_true, 0, 1)), tf.float32)
        y_neg = 1 - y_pos

        fn = backend.sum(y_pos * y_pred_neg)  # relevant studies predicted negative
        tn = backend.sum(y_neg * y_pred_neg)  # irrelevant studies predicted negative
        tp = backend.sum(y_pos * y_pred_pos)  # relevant studies predicted positive
        n = tf.cast(tf.size(y_true), tf.float32)  # number of studies in batch (graph-safe)
        r = tp / (tp + fn + backend.epsilon())  # recall
        # WSS = (TN + FN)/N - 1 + recall = (TN + FN)/N - (1 - recall)
        self.wss.assign(((tn + fn) / n) - (1 - r))

    def result(self):
        return self.wss

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        self.wss.assign(0.)

How can I calculate the WSS at a certain recall? I've seen the following in TensorFlow's own Git repository (the PrecisionAtRecall constructor):

def __init__(self, recall, num_thresholds=200, name=None, dtype=None):
    if recall < 0 or recall > 1:
        raise ValueError('`recall` must be in the range [0, 1].')
    self.recall = recall
    self.num_thresholds = num_thresholds
    super(PrecisionAtRecall, self).__init__(
        value=recall,
        num_thresholds=num_thresholds,
        name=name,
        dtype=dtype)

But that isn't really possible with the keras.metrics.Metric class.

Upvotes: 2

Views: 512

Answers (1)

Lescurel

Reputation: 11651

If we follow the definition of WSS@95 given in the paper Reducing Workload in Systematic Review Preparation Using Automated Citation Classification, then we have:

For the present work, we have fixed recall at 0.95 and therefore work saved over sampling at 95% recall (WSS@95%) is:

WSS@95 = (TN+FN)/N - 0.05
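This is the question's general formula with the recall term fixed. Writing R = TP/(TP + FN) for the recall, the last two terms collapse to -(1 - R), and setting R = 0.95 gives the constant 0.05:

```latex
\mathrm{WSS} = \frac{TN + FN}{N} - 1 + \frac{TP}{TP + FN}
             = \frac{TN + FN}{N} - (1 - R),
\qquad R = 0.95 \;\Rightarrow\; \mathrm{WSS@95} = \frac{TN + FN}{N} - 0.05
```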

You could then define the update function as follows:

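import tensorflow as tf
from tensorflow.keras import backend


class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
    def __init__(self, recall, name='wss_at_recall', **kwargs):
        if recall < 0 or recall > 1:
            raise ValueError('`recall` must be in the range [0, 1].')
        self.recall = recall
        super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
        self.wss = self.add_weight(name='wss', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Binarize predictions and labels at a fixed 0.5 threshold
        y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), tf.float32)
        y_pred_neg = 1 - y_pred_pos
        y_pos = tf.cast(backend.round(backend.clip(y_true, 0, 1)), tf.float32)
        y_neg = 1 - y_pos

        fn = backend.sum(y_pos * y_pred_neg)  # relevant studies predicted negative
        tn = backend.sum(y_neg * y_pred_neg)  # irrelevant studies predicted negative
        n = tf.cast(tf.size(y_true), tf.float32)  # number of studies in batch
        # WSS@recall = (TN + FN)/N - (1 - recall), with recall fixed to the target
        self.wss.assign(((tn + fn) / n) - (1 - self.recall))

    def result(self):
        # Metric subclasses must implement result()
        return self.wss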

Another solution would be to extend TensorFlow's SensitivitySpecificityBase class and implement the WSS the same way the PrecisionAtRecall class is implemented.

With this class, the WSS is calculated as follows:

  • Compute the recall and the WSS at every threshold (200 thresholds by default).
  • Keep only the thresholds whose recall is at least the requested value (0.95 in this case).
  • Return the largest WSS among those thresholds (or 0 if no threshold qualifies).

The number of thresholds controls how finely the requested recall can be matched: more thresholds give a finer grid of candidate operating points.
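The threshold sweep described above can be sketched in plain Python, without TensorFlow, to make the selection rule concrete. This is an illustrative re-implementation, not TensorFlow's code: the function name is hypothetical, the thresholds are assumed to be evenly spaced in [0, 1] (TensorFlow places them slightly differently), and a score is assumed to count as positive when it is strictly greater than the threshold.

```python
def wss_at_recall(y_true, y_score, target_recall=0.95, num_thresholds=200):
    """Largest WSS over thresholds whose recall is >= target_recall (0.0 if none)."""
    n = len(y_true)
    best = None
    for i in range(num_thresholds):
        t = i / (num_thresholds - 1)  # assumption: evenly spaced thresholds in [0, 1]
        tp = fn = tn = fp = 0
        for label, score in zip(y_true, y_score):
            pred = score > t  # assumption: strictly greater counts as positive
            if label and pred:
                tp += 1
            elif label:
                fn += 1
            elif pred:
                fp += 1
            else:
                tn += 1
        recall = tp / (tp + fn) if tp + fn else 0.0
        if recall >= target_recall:  # feasible threshold
            wss = (tn + fn) / n - (1 - target_recall)
            if best is None or wss > best:
                best = wss
    return 0.0 if best is None else best


y = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
s = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]
print(wss_at_recall(y, s, target_recall=0.75))  # ≈ 0.45
```

In the toy data, the best feasible threshold lies between 0.6 and 0.7: three of the four relevant studies are caught (recall 0.75), seven studies are screened out, and WSS = 7/10 - 0.25.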

import tensorflow as tf
# Private module path: it may differ or be missing in newer TensorFlow/Keras releases
from tensorflow.python.keras.metrics import SensitivitySpecificityBase


class WorkSavedOverSamplingAtRecall(SensitivitySpecificityBase):
    def __init__(self, recall, num_thresholds=200, name="wss_at_recall", dtype=None):
        if recall < 0 or recall > 1:
            raise ValueError('`recall` must be in the range [0, 1].')
        self.recall = recall
        self.num_thresholds = num_thresholds
        super(WorkSavedOverSamplingAtRecall, self).__init__(
            value=recall, num_thresholds=num_thresholds, name=name, dtype=dtype
        )

    def result(self):
        recalls = tf.math.div_no_nan(
            self.true_positives, self.true_positives + self.false_negatives
        )
        n = (self.true_negatives + self.true_positives
             + self.false_negatives + self.false_positives)
        # WSS@recall = (TN + FN)/N - (1 - recall), evaluated at every threshold
        wss = tf.math.div_no_nan(
            self.true_negatives + self.false_negatives, n
        ) - (1.0 - self.recall)
        # Largest WSS among the thresholds whose recall >= self.recall
        return self._find_max_under_constraint(
            recalls, wss, tf.math.greater_equal
        )

    def get_config(self):
        """For serialization purposes"""
        config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
        base_config = super(WorkSavedOverSamplingAtRecall, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

Upvotes: 1
