Siegg

Reputation: 119

What would be a proper way of implementing this Riemannian gradient?

So here's a toy example from this paper, "Automatic differentiation for Riemannian optimization on low-rank matrix and tensor-train manifolds":

import torch

def f(X):
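    # Toy objective: the squared Frobenius norm of X.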
    return torch.sum(X**2)

def g(delta_U, delta_V, U, V, f):
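    # Evaluate f on the tangent-space parametrization delta_U @ V^T + U @ delta_V^T.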
    perturbed_matrix = U @ delta_V.t() + delta_U @ V.t()
    return f(perturbed_matrix)

def compute_riemannian_gradient(X):
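    # Differentiate g w.r.t. delta_U and delta_V via autograd; at the expansion
    # point delta_U = U @ diag(S), delta_V = 0, the parametrization equals X itself.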
    U, S, V = torch.svd(X)
    delta_U = U @ torch.diag(S)
    delta_V = torch.zeros_like(V)
    delta_U.requires_grad_(True)
    delta_V.requires_grad_(True)
    perturbed_value = g(delta_U, delta_V, U, V, f)
    perturbed_value.backward()

    return delta_U.grad, delta_V.grad

def apply_gauge_conditions(delta_U, delta_V, V):
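    # Enforce the gauge condition V^T @ delta_V = 0 by projecting out the
    # component of delta_V that lies in the span of V.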
    delta_V -= V @ (V.t() @ delta_V)
    return delta_U, delta_V

def riemannian_gradient(X):
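    # Assemble the Riemannian gradient in the ambient matrix space from the
    # gauge-fixed factor gradients.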
    U, _, V = torch.svd(X)
    delta_U, delta_V = compute_riemannian_gradient(X)
    delta_U, delta_V = apply_gauge_conditions(delta_U, delta_V, V)
    return delta_U @ V.t() + U @ delta_V.t()

X = torch.randn(5, 3)

for i in range(10):
    rgrad = riemannian_gradient(X)
    X = X - 0.01*rgrad
    # X = retraction(X, rgrad, 0.01)
    print(f(X))

So as you can see, in the training inference, I don't need the gradient of X or of [U, S, V]. Instead, I need the gradients from delta_U and delta_V in order to update X. Therefore I can't simply loop over the parameters registered via parameters() if I want to integrate this piece of code into the torch.optim module.

My question is: what is the proper way to implement this optimization algorithm in an optimizer's step() function, when the weight X is updated using gradients of other tensors (delta_U and delta_V)?
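For concreteness, here is roughly what I have in mind: a custom optimizer subclass whose step() ignores p.grad and rebuilds the update direction itself (just a sketch reusing riemannian_gradient from above, with a made-up class name; it side-steps loss.backward() entirely, so I'm not sure it's the right pattern):

class RiemannianSGD(torch.optim.Optimizer):
    def __init__(self, params, lr=0.01):
        super().__init__(params, dict(lr=lr))

    def step(self, closure=None):
        for group in self.param_groups:
            for p in group['params']:
                # The update direction is not p.grad: it is rebuilt from the
                # gradients w.r.t. delta_U and delta_V inside riemannian_gradient().
                rgrad = riemannian_gradient(p.detach())
                with torch.no_grad():
                    p.sub_(group['lr'] * rgrad)

opt = RiemannianSGD([X], lr=0.01)
for i in range(10):
    opt.step()
    print(f(X))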

Upvotes: 0

Views: 54

Answers (1)

prateek

Reputation: 11

I do not understand what you mean by "training inference", and overall, the question needs more clarity.

As I understand from a quick look at the paper, it provides a specialized algorithm for computing derivatives for optimization on low-rank matrix manifolds.

In PyTorch, you can define a custom function with forward and backward methods using torch.autograd.Function, as succinctly described in this gist. Refer to the official documentation for a more extensive overview.
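As a rough illustration (untested, with a made-up class name; the backward() below simply projects the Euclidean gradient of the toy f(X) = sum(X**2) onto the tangent space of the fixed-rank manifold, rather than implementing the paper's exact algorithm):

import torch

class LowRankSquaredNorm(torch.autograd.Function):
    # Forward computes the toy objective; backward returns the projected
    # (Riemannian) gradient instead of the plain Euclidean one.
    @staticmethod
    def forward(ctx, X):
        ctx.save_for_backward(X)
        return torch.sum(X ** 2)

    @staticmethod
    def backward(ctx, grad_output):
        X, = ctx.saved_tensors
        U, _, V = torch.svd(X)
        euclid_grad = 2 * X                    # Euclidean gradient of sum(X**2)
        PU, PV = U @ U.t(), V @ V.t()          # projectors onto column/row spaces
        # Standard tangent-space projection for fixed-rank matrices:
        riem_grad = PU @ euclid_grad + euclid_grad @ PV - PU @ euclid_grad @ PV
        return grad_output * riem_grad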

This class would handle the gradient computation step, while the optimizer would be instantiated with [X] as the list of variables to optimize.
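Roughly (again only a sketch, reusing the hypothetical LowRankSquaredNorm class from above):

X = torch.randn(5, 3, requires_grad=True)
optimizer = torch.optim.SGD([X], lr=0.01)

for i in range(10):
    optimizer.zero_grad()
    loss = LowRankSquaredNorm.apply(X)
    loss.backward()     # fills X.grad with the projected gradient from backward()
    optimizer.step()    # ordinary SGD update using X.grad
    print(loss.item())

A retraction back onto the low-rank manifold, like the retraction(...) call you commented out, would still have to be applied separately after each step.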

Upvotes: 0
