user24521305

Reputation: 1

Finding the Diagonal of Hessian wrt Input for Vector-Valued Functions in PyTorch

Suppose you have a vector-valued function in PyTorch, f: R^N -> R^M.

Computing its full Hessian with respect to the input returns an [M, N, N] tensor, and taking the diagonal of each [N, N] slice gives an [M, N] matrix.

Computing the whole Hessian and then extracting its diagonal is very expensive for large functions like neural networks (sometimes impossible even on an A100 with 80 GB of GPU memory).
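For concreteness, the naive baseline I'm trying to avoid looks like this (a minimal sketch using torch.func.hessian and torch.diagonal, where f is any R^N -> R^M function and x is an [N] input):

H = hessian(f)(x)                                 # [M, N, N]
diag_naive = torch.diagonal(H, dim1=-2, dim2=-1)  # [M, N]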

Note: my goal is to compute the Hessian with respect to the input of the function / neural net, not with respect to its parameters.

PyTorch already provides a technique for something similar in the scalar-valued case (f: R^N -> R, whose Hessian diagonal is an [N] vector) using Hessian-vector products (see hvp in the code below), but it doesn't easily extend to my problem setting.

I've tried writing my own implementation, hvp_vecfwd (see code below). It produces a tensor of the correct shape, but the values inside are not all correct, even on simple vector functions (f_single_vector in the code below).
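If I'm reading it right, jvp(jacrev(f), (x,), (v,)) contracts the [M, N, N] Hessian with v over its last dimension, so with v = ones my hvp_vecfwd returns the row sums of each Hessian slice rather than its diagonal; the two only coincide when the off-diagonal entries are zero (as for f_batch_scalar). A quick check that seems to confirm this, using the definitions from the code below:

H = hessian(f_single_vector)(x_single)  # [M, N, N]
v = torch.ones_like(x_single)
print(torch.allclose(hvp_vecfwd(f_single_vector, x_single, v), H.sum(dim=-1)))  # prints True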

There seem to be possible solutions to this in JAX/TensorFlow, but I haven't seen one that solves it (i) in PyTorch and (ii) for vector-valued functions.

Any help or insights on this would be great!

import torch
from torch.func import vmap, jvp, grad, jacrev, hessian

x_single = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)
x_batch = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0],
                        [7.0, 8.0, 9.0]], requires_grad=False)
layer = torch.nn.Linear(3, 5).requires_grad_(False)


# Simple functions for debugging

# Hessian-vector product for scalar f: R^N -> R (forward-over-reverse)
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]  # H @ v, shape [N]

# My attempt at a Hessian-vector product for vector f: R^N -> R^M
def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]  # shape [M, N]


# Batched f: [B, R^N] -> [B, R]
def f_batch_scalar(x):
    return torch.sum(x*x, dim=-1)


# f: R^N -> R^M
def f_single_vector(x):
    return torch.sigmoid(layer(x))


# Batched f: [B, R^N] -> [B, R^M]
def f_batch_vector(x):
    return torch.sigmoid(layer(x))


print("Batched Full Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], d^2f/dx^2 = [B, N, N]
print(vmap(hessian(f_batch_scalar))(x_batch))

print("Batched Diagonal Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], diag[d^2f/dx^2] = [B, N]
print(vmap(hvp, in_dims=(None, 0, 0))(f_batch_scalar, x_batch, torch.ones_like(x_batch)))

print("Non-Batched Full Hessian of Vector Function")
# f: R^N -> R^M, d^2f/dx^2 = [M, N, N]
print(hessian(f_single_vector)(x_single))

print("Non-Batched Diagonal Hessian of Vector Function")
# f: R^N -> R^M, diag[d^2f/dx^2] = [M, N]
# WRONG
print(hvp_vecfwd(f_single_vector, x_single, torch.ones_like(x_single)))

print("Batched Full Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], d^2f/dx^2 = [B, M, N, N]
print(vmap(hessian(f_batch_vector))(x_batch))

print("Batched Diagonal Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], diag[d^2f/dx^2] = [B, M, N]
# WRONG
print(vmap(hvp_vecfwd, in_dims=(None, 0, 0))(f_batch_vector, x_batch, torch.ones_like(x_batch)))
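For reference, one direction I've been experimenting with (only a sketch; hessian_diag_vec is my own hypothetical helper, checked only on the toy functions above): since e_i^T H_m e_i = H_m[i, i], the Hessian diagonal can be recovered from second-order directional derivatives along the N standard basis vectors, without ever materializing the [M, N, N] tensor:

# Hessian diagonal of f: R^N -> R^M via double-jvp along basis vectors.
# The second directional derivative along e_i is e_i^T H_m e_i = H_m[i, i]
# for every output m; vmap over the identity matrix handles all N
# directions at once, keeping peak memory at O(M*N) rather than O(M*N^2).
def hessian_diag_vec(f, x):
    def d2f_along(v):
        g = lambda y: jvp(f, (y,), (v,))[1]  # directional derivative J_f(y) @ v, shape [M]
        return jvp(g, (x,), (v,))[1]         # v^T H_m v for every output m, shape [M]
    basis = torch.eye(x.shape[-1], dtype=x.dtype)
    return vmap(d2f_along)(basis).T          # [M, N]

print(hessian_diag_vec(f_single_vector, x_single))
# appears to match torch.diagonal(hessian(f_single_vector)(x_single), dim1=-2, dim2=-1)

This still costs N forward passes, so whether it beats M reverse-mode HVPs presumably depends on how M compares to N.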

Upvotes: 0

Views: 104

Answers (0)
