Grzenio

Reputation: 36649

Numerically stable way to calculate normal log likelihood

I am trying to calculate the normal log-likelihood, which is given by:

L = l_1 + l_2 + l_3 + ... + l_n,

where

l_k = log(1/(sqrt(2*PI)*sigma_k)) - 0.5*e_k*e_k

The sigmas are around 0.2, and the e_k are normally distributed with mean 0 and unit variance, so most of them lie between -2 and 2.

I tried the following Java code, where the sigma_k above equals sigmas.get(k)*Math.sqrt(dt):
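As a rough illustration of the magnitudes involved, the sketch below evaluates a single term l_k directly from the formula above. The class name SingleTerm and the helper logLikTerm are hypothetical, and it assumes dt = 1, so that sigma_k = 0.2.

```java
public class SingleTerm {
    // Hypothetical helper: one term l_k = log(1/(sqrt(2*PI)*sigma)) - 0.5*e*e
    static double logLikTerm(double sigma, double e) {
        return Math.log(1.0 / (Math.sqrt(2 * Math.PI) * sigma)) - 0.5 * e * e;
    }

    public static void main(String[] args) {
        // sigma_k = 0.2 (assuming dt = 1); e_k taken from the typical range [-2, 2]
        for (double e : new double[] {-2.0, 0.0, 2.0}) {
            System.out.printf("e = %4.1f  l_k = %.4f%n", e, logLikTerm(0.2, e));
        }
    }
}
```

With these inputs it prints roughly 0.6905 for e = 0 and -1.3095 for e = ±2, so each individual term is of order 1.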

private double new1(List<Double> residuals, List<Double> sigmas, double dt) {
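    // a: sum of the log normalizing constants, log(1/(sqrt(2*PI*dt)*sigma_k))
    double a = 0;
    for(int i=0; i<sigmas.size(); i++) {
        a += Math.log(1.0/(Math.sqrt(2*Math.PI*dt)*sigmas.get(i)));
    }
    // b: sum of squared residuals e_k*e_k
    double b = 0;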
    for(int i=0; i<residuals.size(); i++) {
        b += residuals.get(i)*residuals.get(i);
    }
    return a-0.5*b;
}

However, the theoretical maximum is lower than the maximum I obtained through numerical optimisation, so I suspect that my method is numerically suboptimal.

Upvotes: 1

Views: 1643

Answers (2)

Joop Eggen

Reputation: 109567

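Remark: in some fields, probabilities and statistics are calculated without taking the log, for instance for the frequencies of combinations in linguistics.

The following steps simplify the expression, at the cost of becoming less numerically stable (a product of many factors can overflow or underflow); afterwards one can convert it back into a sum of logs or similar.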


double a = 0;
for(int i=0; i<sigmas.size(); i++) {
    a += Math.log(1.0/(Math.sqrt(2*Math.PI*dt)*sigmas.get(i)));
}

log(x) + log(y) = log(x*y)

double a = 1.0;
for(int i=0; i<sigmas.size(); i++) {
    a *= 1.0/(Math.sqrt(2*Math.PI*dt)*sigmas.get(i));
}
a = Math.log(a);

(1/x)*(1/y) = 1/(x*y)

double a = 1.0;
for(int i=0; i<sigmas.size(); i++) {
    a *= Math.sqrt(2*Math.PI*dt)*sigmas.get(i);
}
a = Math.log(1.0/a);

sqrt(x)^n = (x^0.5)^n = x^(n/2)

import static java.lang.Math.*;

double a = pow(2*PI*dt, sigmas.size() / 2.0);
for(int i=0; i<sigmas.size(); i++) {
    a *= sigmas.get(i);
}
a = -log(a);
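To make the stability trade-off concrete, here is a small sketch comparing the sum-of-logs loop with the final product-then-log form for a long series. The class and method names, the series length of 300, and the values sigma = 0.2 and dt = 1/252 are illustrative assumptions, not taken from the question.

```java
import java.util.Collections;
import java.util.List;

public class ProductVsSum {
    // Sum of logs, as in the original loop: each term is added separately
    static double sumOfLogs(List<Double> sigmas, double dt) {
        double a = 0;
        for (double s : sigmas) {
            a += Math.log(1.0 / (Math.sqrt(2 * Math.PI * dt) * s));
        }
        return a;
    }

    // Final product form: pow(2*PI*dt, n/2) times all sigmas, then -log
    static double productThenLog(List<Double> sigmas, double dt) {
        double a = Math.pow(2 * Math.PI * dt, sigmas.size() / 2.0);
        for (double s : sigmas) {
            a *= s; // the running product can underflow to 0.0 for long series
        }
        return -Math.log(a); // -log(0.0) is Infinity
    }

    public static void main(String[] args) {
        double dt = 1.0 / 252; // assumed step size, for illustration only
        List<Double> sigmas = Collections.nCopies(300, 0.2);
        System.out.println("sum of logs:      " + sumOfLogs(sigmas, dt));
        System.out.println("product then log: " + productThenLog(sigmas, dt));
    }
}
```

With these assumed values the running product drops below the smallest positive double and becomes 0.0, so the product form returns Infinity, while the sum of logs stays finite (about 1036.6).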

Upvotes: 1

zch

Reputation: 15278

I'm not sure if it will greatly improve numerical stability, but your equations can be simplified using logarithm laws:

log(a*b) = log(a) + log(b)
log(1/a) = -log(a)
log(sqrt(a)) = log(a)/2

so you have:

lk = -log(2*pi)/2 - log(sigma_k) - 0.5*e_k*e_k
   = -log(2*pi)/2 - log(dt)/2 - log(sigmas.get(k)) - 0.5*e_k*e_k
   = -log(2*pi*dt)/2 - log(sigmas.get(k)) - 0.5*e_k*e_k

The first term is constant, so in the first loop you only need a -= log(sigmas.get(k)); the constant -n*log(2*pi*dt)/2 can then be added once, outside the loop.

Also, it looks suspicious that the first loop runs to sigmas.size() and the second to residuals.size(), while the equation suggests that the two lists should have the same length.
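A quick way to sanity-check the rewrite of l_k is to evaluate the original form and the simplified form side by side. In the sketch below the class name, the method names, and the input values are illustrative assumptions; s plays the role of sigmas.get(k).

```java
public class CheckRewrite {
    // Original form: log(1/(sqrt(2*pi)*sigma_k)) - 0.5*e^2, with sigma_k = s*sqrt(dt)
    static double original(double s, double dt, double e) {
        double sigmaK = s * Math.sqrt(dt);
        return Math.log(1.0 / (Math.sqrt(2 * Math.PI) * sigmaK)) - 0.5 * e * e;
    }

    // Simplified form: -log(2*pi*dt)/2 - log(s) - 0.5*e^2
    static double simplified(double s, double dt, double e) {
        return -Math.log(2 * Math.PI * dt) / 2 - Math.log(s) - 0.5 * e * e;
    }

    public static void main(String[] args) {
        double[][] cases = {{0.2, 1.0, 0.0}, {0.2, 1.0 / 252, 1.5}, {0.25, 0.01, -2.0}};
        for (double[] c : cases) {
            double diff = original(c[0], c[1], c[2]) - simplified(c[0], c[1], c[2]);
            System.out.printf("s=%.2f dt=%.4f e=%.1f  difference=%.2e%n", c[0], c[1], c[2], diff);
        }
    }
}
```

The differences should be at the level of floating-point round-off, which confirms that the two forms are algebraically the same.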
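Putting both suggestions together, a minimal sketch of the simplified computation might look like this. The class name SimplifiedLogLik, the method name, and the choice to throw IllegalArgumentException on a length mismatch are assumptions for illustration; like the question's code, it assumes the residuals are already standardized.

```java
import java.util.Arrays;
import java.util.List;

public class SimplifiedLogLik {
    // L = -n*log(2*pi*dt)/2 - sum(log(sigmas[k])) - 0.5*sum(e_k^2)
    static double logLikelihood(List<Double> residuals, List<Double> sigmas, double dt) {
        // Each residual is paired with its own sigma, so the lengths must match
        if (residuals.size() != sigmas.size()) {
            throw new IllegalArgumentException("residuals and sigmas differ in length: "
                    + residuals.size() + " vs " + sigmas.size());
        }
        int n = sigmas.size();
        double logSigmaSum = 0;
        double sumSquares = 0;
        for (int i = 0; i < n; i++) {
            logSigmaSum += Math.log(sigmas.get(i));
            double e = residuals.get(i);
            sumSquares += e * e;
        }
        // The constant term is added once, outside the loop
        return -0.5 * n * Math.log(2 * Math.PI * dt) - logSigmaSum - 0.5 * sumSquares;
    }

    public static void main(String[] args) {
        List<Double> residuals = Arrays.asList(0.5, -1.0, 1.5);
        List<Double> sigmas = Arrays.asList(0.2, 0.2, 0.2);
        System.out.println(logLikelihood(residuals, sigmas, 1.0));
    }
}
```

For this small example the result is about 0.3215, which matches summing three terms of about 0.6905 each and subtracting 0.5 * 3.5.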

Upvotes: 0
