10GeV

Reputation: 475

Solving this convex optimization problem in Python, using only SciPy

Background

I'm attempting to solve the following convex optimization problem in Python, using (ideally) only the SciPy package. The goal is to recover the matrix 𝐺. The matrices 𝒥ₘ, 𝒥ₖ, and 𝒯, and the integers 𝓂, 𝓀, and 𝓃, are known and fixed.

[Image: statement of the optimization problem]

I'm given some code that accomplishes this task, using CVX in MATLAB:

J_M = eye(M) - 1/M*ones(M);
J_K = eye(K) - 1/K*ones(K);

S_row = [eye(M) zeros(M, K)];
S_col = [zeros(M, K); eye(K)];

cvx_begin sdp quiet
    cvx_precision low
    cvx_solver sedumi

    variable G(N, N) symmetric;
    variable B(M, K);

    DofG = diag(G)*ones(N, 1)' - 2*G + ones(N, 1)*diag(G)';
    LofG = S_row*DofG*S_col;

    G >= 0;
    G*ones(N, 1) == 0;

    L = cell(M, K);
    for m = 1:M
        for k = 1:K
            L{m, k} = [LofG(m, k) B(m, k); B(m, k) 1];
            L{m, k} >= 0;
        end
    end

    B(:) >= 0;

    minimize square_pos(norm(J_M*(B - T)*J_K, 'fro'))
cvx_end
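
Read off from the CVX code above, the problem can be written as follows. This is a reconstruction from the code, not necessarily the notation of the original statement; the symbols $\mathcal{D}$, $\mathcal{L}$ and $\mathbf{1}$ are names chosen here for illustration, and $N = M + K$ is assumed because the product `S_row*DofG*S_col` requires it.

```latex
\begin{aligned}
\min_{G = G^\top \in \mathbb{R}^{N \times N},\; B \in \mathbb{R}^{M \times K}} \quad
  & \left\| J_M \,(B - T)\, J_K \right\|_F^2 \\
\text{subject to} \quad
  & G \succeq 0, \qquad G\,\mathbf{1} = 0, \\
  & \begin{bmatrix} \mathcal{L}(G)_{mk} & B_{mk} \\ B_{mk} & 1 \end{bmatrix} \succeq 0,
    \qquad m = 1,\dots,M,\; k = 1,\dots,K, \\
  & B_{mk} \ge 0, \\[4pt]
\text{where} \quad
  & \mathcal{D}(G) = \operatorname{diag}(G)\,\mathbf{1}^\top - 2G + \mathbf{1}\,\operatorname{diag}(G)^\top,
    \qquad \mathcal{L}(G) = S_{\text{row}}\,\mathcal{D}(G)\,S_{\text{col}} .
\end{aligned}
```

Because the lower-right entry of each 2x2 block is 1, the Schur complement shows that each block constraint is equivalent to the scalar inequality $\mathcal{L}(G)_{mk} \ge B_{mk}^2$.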

Unfortunately, my knowledge of complex optimization techniques is limited (although the problem itself is, in principle, fairly straightforward), and I have no experience with CVX. Some pieces of the translation are straightforward, while others I struggle with.


Attempt

I've written an (admittedly very sparse) skeleton of an attempt to recreate some pieces of the code.

from scipy.optimize import minimize
import numpy as np

def loss(...):
    # M and K are given by size of T
    J_M = eye(M) - 1/M * ones(M)
    J_K = eye(K) - 1/K * ones(K)

    DofG = diag(G) @ ones(N, 1) - 2 * G + ones(N, 1) @ diag(G)
    LofG = S_row @ DofG @ S_col

    return np.linalg.norm(J_M @ (B - T) @ J_K, 'fro')**2

res = minimize(loss, ...)


What's Missing

I'm not sure what the declarations variable G(N, N) symmetric and variable B(M, K) do, nor how they can be replicated in the Python code. I can't seem to find a simple explanation in the CVX documentation.

I'm also unsure how to replicate the constraints on B and G, and I'm unsure how to replicate the "iterated constraint":

    for m = 1:M
        for k = 1:K
            L{m, k} = [LofG(m, k) B(m, k); B(m, k) 1];
            L{m, k} >= 0;
        end
    end

How can this MATLAB implementation be translated to Python/SciPy?

Upvotes: 0

Views: 728

Answers (1)

Bob

Reputation: 14654

Direct translation using cvxpy

import cvxpy as cp
import numpy as np

def solve(T, verbose=True):
    M,K = T.shape
    N = M + K
    J_M = np.eye(M) - np.ones((M,M))/M
    J_K = np.eye(K) - np.ones((K,K))/K
    S_row = np.eye(M, M+K)
    S_col = np.roll(np.eye(M+K,K), M, axis=0)

    G = cp.Variable((N,N), symmetric=True)
    B = cp.Variable((M,K))

    # MATLAB's diag(G)*ones(N,1)' is an outer product whose (i, j) entry is G[i, i];
    # broadcasting the (N, 1) column Gd against Gd.T reproduces both diagonal terms.
    Gd = cp.reshape(cp.diag(G), (N, 1))
    DofG = Gd + Gd.T - 2*G
    # Imposes N = M + K
    LofG = S_row @ DofG @ S_col
    constraints = [G >> 0, cp.sum(G, axis=1) == 0]
    for m in range(M):
        for k in range(K):
            L = cp.bmat([[LofG[m,k], B[m,k]], [B[m,k],1]])
            constraints.append(L >> 0)
    constraints.append(B >= 0)

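    # CVX minimizes square_pos(norm(...)), i.e. the squared norm. The plain norm
    # used here has the same minimizer, but obj.value is the unsquared norm.
    obj = cp.norm(J_M @ (B - T) @ J_K, 'fro')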
    prob = cp.Problem(cp.Minimize(obj), constraints)
    prob.solve(verbose=verbose)
    return G.value, B.value, obj.value

If you are familiar with NumPy, the code above should be easy to follow. The only difference is that, instead of NumPy functions, I use the corresponding cvxpy functions, which can operate on optimization variables. The cvxpy objects store the variables and the relations between them; when prob.solve is called, cvxpy casts the problem to a standard form and passes it to a solver.
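
The building blocks of the translation can be checked in plain NumPy, without a solver. The sketch below is only an illustration under stated assumptions: the sizes, the random points X, and the helper lmi_holds are made up for the check and are not part of the original problem. It confirms that the selector matrices pick out the top-right M x K block, that the broadcasting form of DofG matches the MATLAB outer-product form, and that the 2x2 semidefinite constraint behaves like l >= b**2.

```python
import numpy as np

M, K = 3, 2
N = M + K

# Selector matrices, built the same way as in solve().
S_row = np.eye(M, N)                       # [I_M, 0], shape (M, N)
S_col = np.roll(np.eye(N, K), M, axis=0)   # [0; I_K], shape (N, K)

# Hypothetical test data: a Gram matrix of centered points,
# so that G is PSD and every row of G sums to zero.
rng = np.random.default_rng(0)
X = rng.standard_normal((N, 2))
X -= X.mean(axis=0)
G = X @ X.T

# MATLAB form: diag(G)*ones(N,1)' - 2*G + ones(N,1)*diag(G)'
g = np.diag(G).reshape(N, 1)
ones = np.ones((N, 1))
D_outer = g @ ones.T - 2 * G + ones @ g.T

# Broadcasting form used in the cvxpy code.
D_bcast = g + g.T - 2 * G

# For a Gram matrix, D[i, j] is the squared distance between points i and j.
D_true = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)

# S_row @ D @ S_col extracts rows 0..M-1 and columns M..N-1.
LofG = S_row @ D_bcast @ S_col

def lmi_holds(l, b, tol=1e-12):
    """Return True if [[l, b], [b, 1]] is positive semidefinite (up to tol)."""
    block = np.array([[l, b], [b, 1.0]])
    return np.linalg.eigvalsh(block).min() >= -tol

assert np.allclose(D_outer, D_bcast)
assert np.allclose(D_bcast, D_true)
assert np.allclose(LofG, D_true[:M, M:])
assert lmi_holds(4.0, 1.5)        # 4.0 >= 1.5**2
assert not lmi_holds(1.0, 1.5)    # 1.0 <  1.5**2
```

Read this way, each 2x2 block in the loop is just a semidefinite encoding of LofG[m, k] >= B[m, k]**2, which is why it is repeated once per (m, k) pair.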

solve(np.eye(3, 2) + 1)
(array([[ 0.68113622, -0.22048912, -0.10273289, -0.55665717,  0.19874296],
        [-0.22048912,  0.68113622, -0.10273289,  0.19874296, -0.55665717],
        [-0.10273289, -0.10273289,  0.08445099,  0.0605074 ,  0.0605074 ],
        [-0.55665717,  0.19874296,  0.0605074 ,  0.48298721, -0.1855804 ],
        [ 0.19874296, -0.55665717,  0.0605074 , -0.1855804 ,  0.48298721]]),
 array([[ 1.17174172e+00,  1.71742424e-01],
        [ 1.71742424e-01,  1.17174172e+00],
        [ 1.50334549e-15, -1.47036515e-16]]),
 6.988985781497536e-07)
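
As a sanity check, the printed result can be tested against the constraints with NumPy alone. This is a sketch that copies the values as displayed above (so they are rounded), and the tolerances in the comments are arbitrary choices rather than solver settings.

```python
import numpy as np

M, K = 3, 2
T = np.eye(M, K) + 1

# Values copied from the printed output (rounded as displayed).
G = np.array([
    [ 0.68113622, -0.22048912, -0.10273289, -0.55665717,  0.19874296],
    [-0.22048912,  0.68113622, -0.10273289,  0.19874296, -0.55665717],
    [-0.10273289, -0.10273289,  0.08445099,  0.0605074 ,  0.0605074 ],
    [-0.55665717,  0.19874296,  0.0605074 ,  0.48298721, -0.1855804 ],
    [ 0.19874296, -0.55665717,  0.0605074 , -0.1855804 ,  0.48298721],
])
B = np.array([
    [1.17174172e+00,  1.71742424e-01],
    [1.71742424e-01,  1.17174172e+00],
    [1.50334549e-15, -1.47036515e-16],
])

# G @ ones == 0 and G PSD, both up to rounding and solver tolerance.
row_sums = G.sum(axis=1)
min_eig = np.linalg.eigvalsh(G).min()

# Top-right block of D(G); equivalent to S_row @ DofG @ S_col.
g = np.diag(G).reshape(-1, 1)
LofG = (g + g.T - 2 * G)[:M, M:]

# Each 2x2 block constraint is equivalent to LofG[m, k] >= B[m, k]**2.
slack = LofG - B**2

# Objective as computed in solve(): the unsquared Frobenius norm.
J_M = np.eye(M) - np.ones((M, M)) / M
J_K = np.eye(K) - np.ones((K, K)) / K
obj = np.linalg.norm(J_M @ (B - T) @ J_K, "fro")

print("largest |row sum| of G:", np.abs(row_sums).max())
print("smallest eigenvalue of G:", min_eig)
print("smallest LofG - B**2:", slack.min())
print("smallest entry of B:", B.min())
print("objective:", obj)
```

A near-zero objective is expected for this input: every row of B - T comes out almost constant, and multiplying by J_K on the right removes row-wise constants.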

Upvotes: 1
