Suppose you have a vector-valued function f: R^n -> R^m in PyTorch. Computing its full Hessian returns an [M, N, N] tensor, and taking the diagonal of each [N, N] slice gives an [M, N] matrix.
Computing the whole Hessian and then extracting its diagonal is very expensive for large functions like neural nets (sometimes impossible even on an A100 with 80 GB of GPU memory).
Note: my goal is to compute the Hessian with respect to the input of the function / neural net, not with respect to its parameters.
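For concreteness, the expensive baseline I'm trying to avoid looks like the sketch below (toy sizes; f, layer, and x are stand-ins for the real net and input):

import torch
from torch.func import hessian

layer = torch.nn.Linear(3, 5).requires_grad_(False)
f = lambda x: torch.sigmoid(layer(x))       # f: R^3 -> R^5, so M = 5, N = 3
x = torch.tensor([1.0, 2.0, 3.0])

H = hessian(f)(x)                           # materializes the full [M, N, N] Hessian
diag = torch.diagonal(H, dim1=-2, dim2=-1)  # per-output diagonals, shape [M, N]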
PyTorch already provides a technique for the scalar-valued case (f: R^n -> R, whose Hessian diagonal is an [N] vector) using a Hessian-vector product (see hvp in the code below), but it doesn't extend easily to my setting.
I've tried my own implementation, hvp_vecfwd (see code below). It produces a tensor of the correct shape, but the values inside are not all correct, even on simple vector functions (f_single_vector below). I suspect the problem is that contracting the Hessian with a vector of ones yields the row sums of each [N, N] slice, which only match the diagonal when the off-diagonal entries are zero.
This question seems to have some possible solutions in JAX/TensorFlow, but I haven't seen one that works (i) in PyTorch and (ii) for vector-valued functions.
Any help or insights on this would be great!
import torch
from torch.func import vmap, jvp, grad, jacrev, hessian

x_single = torch.tensor([1.0, 2.0, 3.0], requires_grad=False)
x_batch = torch.tensor([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0],
                        [7.0, 8.0, 9.0]], requires_grad=False)
layer = torch.nn.Linear(3, 5).requires_grad_(False)

# Simple functions for debugging

# HVP for scalar functions (the technique PyTorch provides)
def hvp(f, x, v):
    return jvp(grad(f), (x,), (v,))[1]

# My attempted HVP for vector functions
def hvp_vecfwd(f, x, v):
    return jvp(jacrev(f), (x,), (v,))[1]

# Batched f: [B, R^N] -> [B, R]
def f_batch_scalar(x):
    return torch.sum(x * x, dim=-1)

# f: R^N -> R^M
def f_single_vector(x):
    return torch.sigmoid(layer(x))

# Batched f: [B, R^N] -> [B, R^M]
def f_batch_vector(x):
    return torch.sigmoid(layer(x))
print("Batched Full Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], d^2f/dx^2 = [B, N, N]
print(vmap(hessian(f_batch_scalar))(x_batch))
print("Batched Diagonal Hessian of Scalar Function")
# Batched f: [B, R^N] -> [B, R], diag[d^2f/dx^2] = [B, N]
print(vmap(hvp, in_dims=(None, 0, 0))(f_batch_scalar, x_batch, torch.ones_like(x_batch)))
print("Non-Batched Full Hessian of Vector Function")
# f: R^N -> R^M, d^2f/dx^2 = [M, N, N]
print(hessian(f_single_vector)(x_single))
print("Non-Batched Diagonal Hessian of Vector Function")
# f: R^N -> R^M, diag[d^2f/dx^2] = [M, N]
# WRONG
print(hvp_vecfwd(f_single_vector, x_single, torch.ones_like(x_single)))
print("Batched Full Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], d^2f/dx^2 = [B, M, N, N]
print(vmap(hessian(f_batch_vector))(x_batch))
print("Batched Diagonal Hessian of Vector Function")
# Batched f: [B, R^N] -> [B, R^M], diag[d^2f/dx^2] = [B, M, N]
# WRONG
print(vmap(hvp_vecfwd, in_dims=(None, 0, 0))(f_batch_vector, x_batch, torch.ones_like(x_batch)))
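For reference, one workaround that should give the true diagonal is the basis-vector trick (e_i^T H e_i per output): contract with each of the N standard basis vectors and re-contract with the same vector, instead of using a single ones vector. A sketch continuing the script above (hessian_diag is my own helper, not a library function); since it still does N HVPs, it isn't obviously cheaper than the full Hessian, which is why I'm hoping for something smarter:

# For basis vector e_i, hvp_vecfwd(f, x, e_i) = H @ e_i = H[:, :, i] (shape [M, N]);
# re-contracting with e_i picks out H[:, i, i], the i-th diagonal entry per output.
def hessian_diag(f, x):
    def diag_entry(v):
        return (hvp_vecfwd(f, x, v) * v).sum(-1)  # shape [M]
    eye = torch.eye(x.shape[-1])
    return vmap(diag_entry)(eye).T                # [N, M] -> [M, N]

print("Non-Batched Diagonal Hessian of Vector Function (basis-vector trick)")
print(hessian_diag(f_single_vector, x_single))  # should match the diagonal of the full Hessian above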