Pedro Rojas
Pedro Rojas

Reputation: 1

Multinomial logit in Julia

I want to perform the multinomial logit regression by my own in Julia. I am following the formulas I have seen in here. However, I have issues in specifying the Hessian matrix.

My code is the following one:

function mlogit(xvars :: Matrix{Float64}, depvar :: Vector{Int64}; tol = 1e-6)

    N = size(depvar, 1) # Length of the data set
    m = length(unique(depvar)) # m categories
    k = size(xvars, 2) # k regressors

    function probability(θ)
        pt = zeros(N, m)
        for j in 1:m
            pt[:,j] = exp.(xvars * θ[:,j])*(j > 1) + (j == 1)*ones(N)
        end
        ptotal = pt * ones(m)

        p = zeros(N)
        for i in 1:N
            outcome = depvar[i]
            p[i] = pt[i, outcome] / ptotal[i]
        end
        return p, pt ./ ptotal
    end

    function likelihood(θ_flat)
        θ_matrix = reshape(θ_flat, k, m)
        p, _ = probability(θ_matrix)
        return -ones(N)' * log.(p)
    end

    function g!(G, θ_flat)
        θ_matrix = reshape(θ_flat, k, m)
        _, prob = probability(θ_matrix)

        G[:] = zeros(m*k)
        for i in 1:N
            y = zeros(m)
            y[depvar[i]] = 1 

            A = (y - prob[i, :])
            G[:] -= kron(A, xvars[i, :])
        end
    end

    function h!(H, θ_flat)
        θ_matrix = reshape(θ_flat, k, m)
        _, prob = probability(θ_matrix)
    
        H[:, :] = zeros(m*k, m*k)
        for i in 1:N
            A = Diagonal(prob[i, :]) - prob[i, :] * prob[i, :]'
            B = xvars[i, :] * xvars[i, :]'
            H[:, :] += kron(A, B)
        end
    end
    
    Initialguess = zeros(Float64, k*m)
    results = optimize(likelihood, g!, h!, Initialguess, Newton(), Optim.Options(iterations = 300, show_trace = true, g_tol = tol))
    opt = Optim.minimizer(results)
    θstar = reshape(opt, k, m)
    _, prob = probability(θstar)

    return results, θstar[:, 2:end], likelihood(opt), prob
end

I know that the problem is the Hessian because the function works perfectly fine if I take it out (when I use it, the function runs but the problem does not converge). Can you help me to figure out what I am doing wrong here?

Upvotes: 0

Views: 72

Answers (0)

Related Questions