Reputation: 51
I am trying to build a metric that is comparable to the metrics.PrecisionAtRecall class. Therefore, I've tried to build a custom metric by extending the keras.metrics.Metric class.
The original function is WSS = (TN + FN)/N − 1 + TP/(TP + FN), and it should be calculated at a certain recall value, say 95%.
What I have until now is the following:
class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
    """Work Saved over Sampling (WSS), accumulated over every batch seen.

    WSS = (TN + FN) / N - 1 + TP / (TP + FN)

    Labels and predictions are binarised by clipping to [0, 1] and rounding.

    Args:
        recall: Target recall level. It is stored on the instance and
            serialised by `get_config`, but the recall term of the formula
            is the recall measured from the rounded predictions.
        name: Name of the metric.
        **kwargs: Forwarded to `tf.keras.metrics.Metric`.
    """

    def __init__(self, recall, name='wss_at_recall', **kwargs):
        super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
        self.recall = recall
        # Running confusion-matrix counts; WSS is derived from them in
        # `result` so the value reflects all batches, not only the last one.
        self.true_positives = self.add_weight(name='tp', initializer='zeros')
        self.false_negatives = self.add_weight(name='fn', initializer='zeros')
        self.true_negatives = self.add_weight(name='tn', initializer='zeros')
        self.total = self.add_weight(name='n', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the counts of one batch. `sample_weight` is ignored."""
        y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), self.dtype)
        y_pred_neg = 1 - y_pred_pos
        y_pos = tf.cast(backend.round(backend.clip(y_true, 0, 1)), self.dtype)
        y_neg = 1 - y_pos

        self.true_positives.assign_add(backend.sum(y_pos * y_pred_pos))
        # A false negative is a positive label that was predicted negative.
        self.false_negatives.assign_add(backend.sum(y_pos * y_pred_neg))
        self.true_negatives.assign_add(backend.sum(y_neg * y_pred_neg))
        # tf.size also works when the batch dimension is unknown at trace time.
        self.total.assign_add(tf.cast(tf.size(y_true), self.dtype))

    def result(self):
        """Return (TN + FN) / N - (1 - recall) from the accumulated counts."""
        recall = tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_negatives
        )
        predicted_negative_rate = tf.math.divide_no_nan(
            self.true_negatives + self.false_negatives, self.total
        )
        return predicted_negative_rate - (1.0 - recall)

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        for counter in (
            self.true_positives,
            self.false_negatives,
            self.true_negatives,
            self.total,
        ):
            counter.assign(0.)

    def get_config(self):
        """Return the constructor arguments needed to re-create the metric."""
        config = super(WorkSavedOverSamplingAtRecall, self).get_config()
        config['recall'] = self.recall
        return config
How can I calculate the WSS at a certain recall? I've seen the following in tensorflow's own git repository:
def __init__(self, recall, num_thresholds=200, name=None, dtype=None):
    """Validate `recall`, record the settings and initialise the parent.

    Args:
        recall: Recall level; must lie in the range [0, 1].
        num_thresholds: Stored on the instance and forwarded to the parent
            initialiser.
        name: Forwarded to the parent initialiser.
        dtype: Forwarded to the parent initialiser.

    Raises:
        ValueError: If `recall` is below 0 or above 1.
    """
    out_of_range = recall < 0 or recall > 1
    if out_of_range:
        raise ValueError('`recall` must be in the range [0, 1].')

    self.recall = recall
    self.num_thresholds = num_thresholds

    parent_kwargs = dict(
        value=recall,
        num_thresholds=num_thresholds,
        name=name,
        dtype=dtype,
    )
    super(PrecisionAtRecall, self).__init__(**parent_kwargs)
But that isn't really possible through the keras.metrics.Metric class.
Upvotes: 2
Views: 512
Reputation: 11651
If we follow the definition of WSS@95 given in the paper "Reducing Workload in Systematic Review Preparation Using Automated Citation Classification", then we have WSS = (TN + FN)/N − (1 − R), where R is the recall.
For the present work, we have fixed recall at 0.95 and therefore work saved over sampling at 95% recall (WSS@95%) is: WSS@95% = (TN + FN)/N − 0.05
And you could define your update function by :
class WorkSavedOverSamplingAtRecall(tf.keras.metrics.Metric):
    """Work Saved over Sampling at a fixed recall level.

    WSS@R = (TN + FN) / N - (1 - R)

    Labels and predictions are binarised by clipping to [0, 1] and rounding;
    counts are accumulated over every batch seen since the last reset.

    Args:
        recall: The fixed recall level R, in the range [0, 1].
        name: Name of the metric.
        **kwargs: Forwarded to `tf.keras.metrics.Metric`.

    Raises:
        ValueError: If `recall` is outside [0, 1].
    """

    def __init__(self, recall, name='wss_at_recall', **kwargs):
        if recall < 0 or recall > 1:
            raise ValueError('`recall` must be in the range [0, 1].')
        super(WorkSavedOverSamplingAtRecall, self).__init__(name=name, **kwargs)
        self.recall = recall
        # Running counts; WSS is derived from them in `result`.
        self.true_negatives = self.add_weight(name='tn', initializer='zeros')
        self.false_negatives = self.add_weight(name='fn', initializer='zeros')
        self.total = self.add_weight(name='n', initializer='zeros')

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Add the counts of one batch. `sample_weight` is ignored."""
        y_pred_pos = tf.cast(backend.round(backend.clip(y_pred, 0, 1)), self.dtype)
        y_pred_neg = 1 - y_pred_pos
        y_pos = tf.cast(backend.round(backend.clip(y_true, 0, 1)), self.dtype)
        y_neg = 1 - y_pos

        # A false negative is a positive label that was predicted negative.
        self.false_negatives.assign_add(backend.sum(y_pos * y_pred_neg))
        self.true_negatives.assign_add(backend.sum(y_neg * y_pred_neg))
        # tf.size also works when the batch dimension is unknown at trace time.
        self.total.assign_add(tf.cast(tf.size(y_true), self.dtype))

    def result(self):
        """Return (TN + FN) / N - (1 - recall) from the accumulated counts."""
        predicted_negative_rate = tf.math.divide_no_nan(
            self.true_negatives + self.false_negatives, self.total
        )
        return predicted_negative_rate - (1.0 - self.recall)

    def reset_states(self):
        # The state of the metric will be reset at the start of each epoch.
        for counter in (self.true_negatives, self.false_negatives, self.total):
            counter.assign(0.)

    def get_config(self):
        """Return the constructor arguments needed to re-create the metric."""
        config = super(WorkSavedOverSamplingAtRecall, self).get_config()
        config['recall'] = self.recall
        return config
Another solution would be to extend the TensorFlow class SensitivitySpecificityBase
and to implement the WSS the same way the PrecisionAtRecall class is implemented.
By using this class, here's how the WSS is calculated :
The number of thresholds is used to match the given recall.
import tensorflow as tf
from tensorflow.python.keras.metrics import SensitivitySpecificityBase
class WorkSavedOverSamplingAtRecall(SensitivitySpecificityBase):
    """Best Work Saved over Sampling among thresholds reaching a recall level.

    For every threshold the recall and WSS = (TN + FN) / N - (1 - recall_target)
    are computed from the confusion-matrix counters kept by the base class;
    the result is the largest WSS among thresholds whose recall is greater
    than or equal to the requested recall.

    Args:
        recall: Requested recall level, in the range [0, 1].
        num_thresholds: Forwarded to the base class and stored for
            `get_config`.
        name: Name of the metric.
        dtype: Forwarded to the base class.

    Raises:
        ValueError: If `recall` is outside [0, 1].
    """

    def __init__(self, recall, num_thresholds=200, name="wss_at_recall", dtype=None):
        if recall < 0 or recall > 1:
            raise ValueError('`recall` must be in the range [0, 1].')
        self.recall = recall
        self.num_thresholds = num_thresholds
        super(WorkSavedOverSamplingAtRecall, self).__init__(
            value=recall, num_thresholds=num_thresholds, name=name, dtype=dtype
        )

    def result(self):
        """Return the maximum WSS over thresholds with recall >= `recall`."""
        # tf.math.divide_no_nan is the public TF2 name of this op.
        recalls = tf.math.divide_no_nan(
            self.true_positives, self.true_positives + self.false_negatives
        )
        n = (
            self.true_negatives
            + self.true_positives
            + self.false_negatives
            + self.false_positives
        )
        predicted_negative_rate = tf.math.divide_no_nan(
            self.true_negatives + self.false_negatives, n
        )
        # The offset is constant across thresholds, so it does not change
        # which threshold wins; it only aligns the value with the definition
        # WSS@R = (TN + FN) / N - (1 - R).
        wss = predicted_negative_rate - (1.0 - self.recall)
        return self._find_max_under_constraint(
            recalls, wss, tf.math.greater_equal
        )

    def get_config(self):
        """For serialization purposes"""
        config = {'num_thresholds': self.num_thresholds, 'recall': self.recall}
        base_config = super(WorkSavedOverSamplingAtRecall, self).get_config()
        return {**base_config, **config}
Upvotes: 1