reese0106

Reputation: 2061

Tensorflow plot tf.metrics.precision_at_thresholds in Tensorboard through eval_metric_ops

tf.metrics.precision_at_thresholds() takes three arguments: labels, predictions, thresholds, where thresholds is a Python list or tuple of thresholds in [0, 1]. The function then returns "A float Tensor of shape [len(thresholds)]", which is problematic for automatically plotting eval_metric_ops to TensorBoard (as I believe a scalar is expected there). The values print to the console just fine, but I would also like to plot them in TensorBoard. Is there any adjustment that can be made to plot the value in TensorBoard?
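
For context, a rough sketch of the kind of setup I mean (the model, feature name and thresholds here are just illustrative):

import tensorflow as tf

def model_fn(features, labels, mode, params):
    # ... model elided; `probabilities` are sigmoid outputs in [0, 1],
    # reshaped to the same shape as `labels` ...
    logits = tf.layers.dense(features['x'], 1)
    probabilities = tf.reshape(tf.nn.sigmoid(logits), [-1])
    loss = tf.losses.log_loss(labels=labels, predictions=probabilities)

    # tf.metrics.precision_at_thresholds returns a (value, update_op) pair whose
    # value tensor has shape [len(thresholds)] rather than a scalar, which is
    # what breaks the automatic TensorBoard plots for eval_metric_ops
    eval_metric_ops = {
        'precision': tf.metrics.precision_at_thresholds(
            labels=labels, predictions=probabilities, thresholds=[0.5, 0.7]),
    }
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                      eval_metric_ops=eval_metric_ops)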

Upvotes: 5

Views: 1366

Answers (2)

patzm

Reputation: 1082

I found it really strange that TensorFlow (as of 1.8) does not offer a summary function for metrics like tf.metrics.precision_at_thresholds (in general tf.metrics.*_at_thresholds). The following is a minimal working example:

def summarize_metrics(metrics_update_ops):
    for metric_op in metrics_update_ops:
        shape = metric_op.shape.as_list()
        if shape:  # this is a metric created with any of tf.metrics.*_at_thresholds
            summary_components = tf.split(metric_op, shape[0])
            for i, summary_component in enumerate(summary_components):
                tf.summary.scalar(
                    name='{op_name}_{i}'.format(op_name=metric_op.op.name, i=i),
                    tensor=tf.squeeze(summary_component, axis=[0])
                )
        else:  # this already is a scalar metric operator
            tf.summary.scalar(name=metric_op.op.name, tensor=metric_op)

precision, precision_op = tf.metrics.precision_at_thresholds(labels=labels,
                                                             predictions=predictions,
                                                             thresholds=thresholds)
summarize_metrics([precision_op])

The downside of this approach, in general, is that the knowledge of which thresholds were used to create the metric in the first place is lost when summarizing it. I came up with a slightly more complex, but easier-to-use solution that uses collections to store all metric update operators.

# Create a metric and let it add the vars and update operators to the specified collections
thresholds = [0.5, 0.7]
tf.metrics.recall_at_thresholds(
    labels=labels, predictions=predictions, thresholds=thresholds,
    metrics_collections='metrics_vars', updates_collections='metrics_update_ops'
)

# Anywhere else call the summary method I provide in the Gist at the bottom [1]
# Because we provide a mapping of a scope pattern to the thresholds, we can
# assign them later
summarize_metrics(list_lookup={'recall_at_thresholds': thresholds})
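
In case the Gist link goes stale, here is a rough sketch of the idea behind that collections-based summarize_metrics (an approximation only, not the actual Gist code; the collection name and the list_lookup matching are my assumptions):

import tensorflow as tf

def summarize_metrics(list_lookup=None, collection='metrics_update_ops'):
    # Sketch only: pull every metric update op from the collection and emit one
    # scalar summary per threshold, recovering the thresholds via list_lookup
    list_lookup = list_lookup or {}
    for metric_op in tf.get_collection(collection):
        shape = metric_op.shape.as_list()
        if not shape:  # plain scalar metric
            tf.summary.scalar(name=metric_op.op.name, tensor=metric_op)
            continue
        # find the thresholds whose scope pattern occurs in this op's name
        thresholds = next((t for pattern, t in list_lookup.items()
                           if pattern in metric_op.op.name), None)
        for i, component in enumerate(tf.split(metric_op, shape[0])):
            suffix = thresholds[i] if thresholds else i
            tf.summary.scalar(
                name='{}_{}'.format(metric_op.op.name, suffix),
                tensor=tf.squeeze(component, axis=[0]))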

The implementation in the Gist [1] below also supports options for formatting the sometimes cryptic names of the metrics nicely.

[1]: https://gist.github.com/patzm/961dcdcafbf3c253a056807c56604628

What this could look like: [TensorBoard screenshot on Imgur]

Upvotes: 5

reese0106

Reputation: 2061

My current approach is to create a separate function that just takes the mean of the first element in the list. However, I expect there is a more elegant solution than this:

def metric_fn(labels, predictions, threshold):
    precision, precision_op = tf.metrics.precision_at_thresholds(labels=labels,
                                                                  predictions=predictions,
                                                                  thresholds=threshold)
    # take the mean of the first threshold's precision; group precision_op into
    # the returned update op so the underlying precision variables get updated
    mean, mean_op = tf.metrics.mean(precision[0])

    return mean, tf.group(precision_op, mean_op)
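
Plugged into an Estimator's model_fn, this looks roughly like the following (a sketch; the dict keys and the one-entry-per-threshold split are just how I happen to use it):

# one eval_metric_ops entry per threshold so that each gets its own scalar plot
eval_metric_ops = {
    'precision_at_0.5': metric_fn(labels, predictions, threshold=[0.5]),
    'precision_at_0.7': metric_fn(labels, predictions, threshold=[0.7]),
}
return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                  eval_metric_ops=eval_metric_ops)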

Upvotes: 0
