jasmine

Reputation: 231

autodiff for max()

I am doing optimization using maximum likelihood estimation, and when I try to get the standard errors of the estimates from the Hessian matrix, I get a non-invertible/singular Hessian warning.

After inspecting the Hessian, I find a column/row of all zeros. What does this tell me? Should I worry about the identifiability of my model?

jac [ 1.19782970e+00 -4.47698715e+01  1.07446142e+01 -0.00000000e+00
 -0.00000000e+00  1.49231202e-07 -0.00000000e+00  1.07650748e-08]
hes = jnp.outer(jac, jac) 
[[ 1.43479598e+00 -5.36266815e+01  1.28702180e+01 -0.00000000e+00
  -0.00000000e+00  1.78753565e-07 -0.00000000e+00  1.28947263e-08]
 [-5.36266815e+01  2.00434139e+03 -4.81034997e+02  0.00000000e+00
   0.00000000e+00 -6.68106171e-06  0.00000000e+00 -4.81951017e-07]
 [ 1.28702180e+01 -4.81034997e+02  1.15446735e+02 -0.00000000e+00
  -0.00000000e+00  1.60343169e-06 -0.00000000e+00  1.15666576e-07]
 [-0.00000000e+00  0.00000000e+00 -0.00000000e+00  0.00000000e+00
   0.00000000e+00 -0.00000000e+00  0.00000000e+00 -0.00000000e+00]
 [-0.00000000e+00  0.00000000e+00 -0.00000000e+00  0.00000000e+00
   0.00000000e+00 -0.00000000e+00  0.00000000e+00 -0.00000000e+00]
 [ 1.78753565e-07 -6.68106171e-06  1.60343169e-06 -0.00000000e+00
  -0.00000000e+00  2.22699515e-14 -0.00000000e+00  1.60648505e-15]
 [-0.00000000e+00  0.00000000e+00 -0.00000000e+00  0.00000000e+00
   0.00000000e+00 -0.00000000e+00  0.00000000e+00 -0.00000000e+00]
 [ 1.28947263e-08 -4.81951017e-07  1.15666576e-07 -0.00000000e+00
  -0.00000000e+00  1.60648505e-15 -0.00000000e+00  1.15886836e-16]]
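
A quick way to see which parameters the zero rows belong to (a minimal sketch; hes is the array printed above):

which = jnp.where(jnp.all(hes == 0, axis=1))[0]
print(which)   # indices of the parameters whose entire row is zero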

I have also read in other posts that the singularity might be caused by precision error, and that adding λI would fix the problem. But how can I tell whether my problem is caused by precision in Python (I am using float64, and autodifferentiation from JAX for the Hessian)?
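
To be concrete, this is the kind of check I have in mind (a minimal sketch): look at the eigenvalues to distinguish an exactly zero direction from one that is merely tiny, and only then consider a small ridge.

eigvals = jnp.linalg.eigvalsh(hes)
print(eigvals)                 # exact zeros suggest a structural problem, not round-off
lam = 1e-8                     # small ridge, only if the issue is conditioning
inv_ridge = jnp.linalg.inv(hes + lam * jnp.eye(hes.shape[0]))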

Code to reproduce:

from joblib import Parallel, delayed
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

class Model:
    def __init__(self, data):
        self.data = data

    def get_p_l0(self,phi,id,time):
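        # Exponentially decaying (phi-weighted) average of this id's past
        # observations (columns 2-4), mapped to attractions via get_attraction.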
        df=self.data
        subdf = df[(df[:,0]==id)&(df[:,1]>=1)&(df[:,1]<=time)]

        weights=jnp.array([phi ** i for i in range(0, int(time))])[::-1]
        dem = weights.sum()

        nom1 = (weights * subdf[:,2]).sum()
        nom2 = (weights * subdf[:,3]).sum()
        nom3 = (weights * subdf[:,4]).sum()
        b1 = nom1 / dem
        b2 = nom2 / dem
        b3 = nom3 / dem

        a1,a2,a3 = self.get_attraction(b1, b2, b3)
        return [a1,a2,a3]

    def get_attraction(self, belief1, belief2, belief3):
        a1=20 * belief1 + 12 * belief2 + 6 * belief3
        a2=12 * belief1 + 24 * belief2 + 18 * belief3
        a3 = 23 * belief1 + 0 * belief2 + 40 * belief3
        return a1,a2,a3

    def get_prob(self, a1, a2, a3, lamda):
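        # Numerically stable softmax (log-sum-exp trick) over the three
        # lamda-scaled attractions.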
        m1 = lamda * a1
        m2 = lamda * a2
        m3 = lamda * a3

        m=jnp.array([m1,m2,m3]).T
        c=jnp.amax(m,axis=1)
        aa1 = jnp.exp(m1-c)
        aa2 = jnp.exp(m2-c)
        aa3 = jnp.exp(m3-c)
        logsumexp = c+jnp.log(jnp.array([aa1,aa2,aa3]).sum(axis=0))
        p1_ = jnp.exp(m1-logsumexp)
        p2_ = jnp.exp(m2-logsumexp)
        p3_ = jnp.exp(m3-logsumexp)

        return p1_, p2_, p3_

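    # Note: Python's built-in max() plus the hard 0/1 assignment below make this
    # method piecewise constant in the attractions, so the autodiff derivative
    # with respect to phi_hat is zero almost everywhere, which is consistent
    # with the all-zero Hessian row for phi_hat shown further down.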
    def get_sub_prob(self,attractions):
        m = max(attractions)
        m_actions=[i+1 for i, j in enumerate(attractions) if j == m]

        if m_actions==[1]:
            p1_l1_l0,p2_l1_l0,p3_l1_l0 = 1.0, 0.0, 0.0
        elif m_actions==[2] or m_actions==[1,2]:
            p1_l1_l0, p2_l1_l0, p3_l1_l0 = 0.0, 1.0, 0.0
        else:
            p1_l1_l0, p2_l1_l0, p3_l1_l0 = 0.0, 0.0, 1.0

        return p1_l1_l0, p2_l1_l0, p3_l1_l0

    def get_prob_t_l1(self,p_hat,p1_l1_l0, p2_l1_l0, p3_l1_l0,):
        b1_t_a1 = p_hat+(1-p_hat)*p1_l1_l0
        b2_t_a1 = (1 - p_hat) * p2_l1_l0
        b3_t_a1 = (1 - p_hat) * p3_l1_l0

        a1_t_a1,a2_t_a1,a3_t_a1 = self.get_attraction(b1_t_a1,b2_t_a1,b3_t_a1)

        b1_t_a2 = (1 - p_hat) * p1_l1_l0
        b2_t_a2 = p_hat + (1 - p_hat) * p2_l1_l0
        b3_t_a2 = (1 - p_hat) * p3_l1_l0

        a1_t_a2, a2_t_a2, a3_t_a2 = self.get_attraction(b1_t_a2, b2_t_a2, b3_t_a2)

        b1_t_a3 = (1 - p_hat) * p1_l1_l0
        b2_t_a3 = (1 - p_hat) * p2_l1_l0
        b3_t_a3 = p_hat + (1 - p_hat) * p3_l1_l0

        a1_t_a3, a2_t_a3, a3_t_a3 = self.get_attraction(b1_t_a3, b2_t_a3, b3_t_a3)

        return a1_t_a1, a2_t_a2, a3_t_a3

    def get_l1_p_t_id(self, id, phi_hat, time):
        p_hat = 0.5

        a_l1_l0 = self.get_p_l0(phi=phi_hat, id=id, time=time)
        p1_l1_l0, p2_l1_l0, p3_l1_l0 = self.get_sub_prob(a_l1_l0)

        eu1, eu2, eu3 = self.get_prob_t_l1(p_hat,p1_l1_l0, p2_l1_l0, p3_l1_l0)

        return [eu1, eu2, eu3]

    def likelihood(self, params):
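        # Negative log-likelihood of a two-type mixture: each id's sequence is
        # generated by the level-0 process with probability (1 - p1) and by the
        # level-1 process with probability p1.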
        p1 = params[0]
        lamda = params[1]
        phi = params[2]
        phi_hat = params[3]
        #k = params[4]

        df=self.data

        num_cores=1
        a_t_l0 = Parallel(n_jobs=num_cores)(
            delayed(self.get_p_l0)(id=row[0], time=row[1], phi=phi) for row in df)

        a_t_l0=jnp.array(a_t_l0)
        p1_bl, p2_bl, p3_bl = self.get_prob(a_t_l0[:,0], a_t_l0[:,1], a_t_l0[:,2], lamda)

        a_t_l1 = Parallel(n_jobs=num_cores)(delayed(self.get_l1_p_t_id)(id=row[0], time=row[1],phi_hat=phi_hat
                                                                        ) for row in df)
        a_t_l1=jnp.array(a_t_l1)
        p1_l1_t, p2_l1_t, p3_l1_t = self.get_prob(a_t_l1[:,0], a_t_l1[:,1], a_t_l1[:,2], lamda)

        f_l0 = p1_bl * df[:,5] + p2_bl * df[:,6] + p3_bl * df[:,7]
        f_l1 = p1_l1_t * df[:,5] + p2_l1_t * df[:,6] + p3_l1_t * df[:,7]

        num_ids = jnp.unique(df[:,0]).shape[0]
        ff_l0 = jnp.array([jnp.prod(i) for i in jnp.split(f_l0,num_ids)])
        ff_l1 = jnp.array([jnp.prod(i) for i in jnp.split(f_l1,num_ids)])

        pp = (1-p1) * ff_l0 + p1 * ff_l1

        li = jnp.log(pp)
        ll = jnp.sum(li)

        return -ll

    def get_se(self, params):
        hes = jax.hessian(self.likelihood)(params)
        print("hes",hes,)
        inv_hessian = jnp.linalg.inv(hes)
        print("inv",inv_hessian)
        diag_inv_hessian = jnp.diag(inv_hessian)
        print("diag",diag_inv_hessian)
        se = jnp.sqrt(diag_inv_hessian)
        return se

    def test(self, x0):
        print(self.likelihood(x0))

from jax import random
key = random.PRNGKey(1)
id = jnp.append(jnp.ones((100,)), jnp.ones((100,))*2)
choice = random.randint(key,(200,),1,4)
y1 = jnp.where(choice==1, 1, 0)
y2 = jnp.where(choice==2, 1, 0)
y3 = jnp.where(choice==3, 1, 0)
s = random.dirichlet(key, alpha=jnp.ones((3,)), shape=(200,), dtype=jnp.float64)
period=jnp.array([jnp.arange(1, 101)]*2).flatten()
data=jnp.column_stack((id,period,s,y1,y2,y3))


# Test
x0 = jnp.array([ 0.8,  2.745e-01 , 2.584e-01, 0.8],dtype=jnp.float64)
se = Model(data).get_se(x0)

hes [[ 5.00000000e+01 -4.06378551e-13  4.17754693e-14  0.00000000e+00]
 [-3.40315672e-13  1.23290937e+03 -6.83135224e+02  0.00000000e+00]
 [ 4.79170617e-14 -6.83135224e+02  8.59001000e+01  0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]

I have a max() here; could it be the problem?
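
(A minimal sketch of what I mean: a hard, argmax-style selection like the one in get_sub_prob is piecewise constant, so its gradient is zero almost everywhere.)

def hard_select(x):
    # 1.0 if the first entry is the (weak) maximum, else 0.0 -- piecewise constant in x
    return jnp.where(x[0] >= jnp.max(x), 1.0, 0.0)

print(jax.grad(hard_select)(jnp.array([2.0, 1.0])))  # [0. 0.]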

When I use Nelder-Mead for minimization, phi_hat gets updated (I can see it in the callback). If I use L-BFGS-B, phi_hat stays stuck at its initial value. Could it be that L-BFGS-B uses the Jacobian for optimization while Nelder-Mead does not, and so I should not use autodiff or numerical differentiation to get the Hessian?
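
(For reference, a minimal sketch of the kind of call I mean, assuming scipy is available and using the Model and x0 defined above; the names model, obj, grad are just for illustration:)

import numpy as np
from scipy.optimize import minimize

model = Model(data)
obj = lambda p: float(model.likelihood(jnp.array(p)))
grad = lambda p: np.asarray(jax.grad(model.likelihood)(jnp.array(p)))

res = minimize(obj, np.asarray(x0), jac=grad, method="L-BFGS-B")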

Update: after dropping the max()-related part, everything works.

Upvotes: 0

Views: 101

Answers (1)

jakevdp

Reputation: 86330

You're not computing the Hessian; you're computing the outer product of the Jacobian with itself. The outer product of a vector with itself is rank one, so it will always be singular (whenever there is more than one parameter).
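
A minimal illustration with an arbitrary smooth function (a sketch, not your model):

import jax
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)

def f(x):
    return jnp.sum(jnp.sin(x) * x ** 2)

x = jnp.array([0.3, 1.2, -0.7])
g = jax.grad(f)(x)
outer = jnp.outer(g, g)       # rank 1, hence singular for more than one parameter
hess = jax.hessian(f)(x)      # the actual Hessian

print(jnp.linalg.matrix_rank(outer))  # 1
print(jnp.linalg.matrix_rank(hess))   # 3 (full rank)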

Upvotes: 0
