Elcor

Reputation: 143

TensorFlow Probability: how to do sample weighting with log_prob loss function?

I'm using TensorFlow Probability to train a probabilistic regression model whose output is a tfp.distributions.Independent object. My problem is that I'm unsure how to implement sample weighting in the negative log likelihood (NLL) loss function.

I have the following loss function, which I believe ignores its third argument, sample_weight:

import tensorflow as tf


class NLL(tf.keras.losses.Loss):
    ''' Custom Keras loss/metric for negative log likelihood '''

    def __call__(self, y_true, y_pred, sample_weight=None):
        # sample_weight is accepted here but never used
        return -y_pred.log_prob(y_true)

With standard TensorFlow loss functions and a dataset yielding (X, y, sample_weight) tuples, the weighting of the loss reduction sums is handled under the hood. How can I make the sum inside y_pred.log_prob use the weights in the sample_weight tensor?
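
For context, here is a minimal sketch of the weighted-dataset setup described above (the arrays and model are placeholders), using a built-in loss where the weighting is applied automatically:

import numpy as np
import tensorflow as tf

# Placeholder data: 100 samples, 5 features, 1 regression target
X = np.random.rand(100, 5).astype("float32")
y = np.random.rand(100, 1).astype("float32")
sample_weight = np.random.rand(100).astype("float32")

# Dataset of (X, y, sample_weight) tuples
dataset = tf.data.Dataset.from_tensor_slices((X, y, sample_weight)).batch(32)

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
# With a built-in loss, Keras applies sample_weight during loss reduction
model.compile(optimizer="adam", loss="mse")
model.fit(dataset, epochs=1)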

Upvotes: 3

Views: 949

Answers (1)

Elcor

Reputation: 143

I found a solution to my problem as posted in this GitHub issue.

My problem was caused by the fact that my model outputs a tfp.distributions.Independent distribution, whose log_prob sums the individual log_probs over the reinterpreted event dimensions and returns a single value per sample rather than one value per element. This prevents weighting individual elements of the loss function.

You can get the underlying tensor of per-element log_prob values by accessing the .distribution attribute of the Independent object: the underlying distribution treats each element of the loss as an independent random variable, rather than as a single random variable with multiple values. Because the loss function inherits from tf.keras.losses.Loss, the resulting weighted tensor is implicitly reduced, returning the weighted mean of the log_prob values rather than their sum, e.g.:

class NLL(tf.keras.losses.Loss):
    ''' Custom Keras loss/metric for weighted negative log likelihood '''

    def __call__(self, y_true, y_pred, sample_weight=None):
        # Per-element negative log likelihoods from the underlying distribution
        nll = -y_pred.distribution.log_prob(y_true)
        if sample_weight is not None:
            nll = nll * sample_weight
        # This tensor is implicitly reduced by TensorFlow
        # by taking the mean over all weighted elements
        return nll
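
For completeness, here is a small usage sketch (not part of the original answer; the toy distribution, shapes, and weights are illustrative assumptions) contrasting the summed and per-element log_prob and showing the weighted loss tensor:

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

batch, event_size = 4, 3

# Toy Independent Normal: batch shape (4,), event shape (3,)
dist = tfd.Independent(
    tfd.Normal(loc=tf.zeros([batch, event_size]), scale=1.0),
    reinterpreted_batch_ndims=1)

y_true = tf.random.normal([batch, event_size])
sample_weight = tf.random.uniform([batch, event_size])  # per-element weights

print(dist.log_prob(y_true).shape)               # (4,)  summed over event dims
print(dist.distribution.log_prob(y_true).shape)  # (4, 3) one value per element

# Weighted per-element NLL; Keras reduces this to a mean during training
loss = NLL()(y_true, dist, sample_weight=sample_weight)
print(loss.shape)  # (4, 3)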

Upvotes: 5
