user19095
user19095

Reputation: 105

Hessian matrix of a Keras model with tf.hessians

I want to compute the Hessian matrix of a Keras model w.r.t. its input in graph mode using tf.hessians. Here is a minimal example:

import tensorflow as tf
from tensorflow import keras

# A 10-feature input followed by a single Dense(1) layer with no activation,
# so the model's output is an affine (linear) function of its input.
model = keras.Sequential(
    [
        keras.Input(shape=(10,)),
        keras.layers.Dense(units=1),
    ]
)
model.summary()

@tf.function
def get_grads(inputs):
    """Gradient of the summed model output with respect to ``inputs``.

    Returns the list produced by ``tf.gradients`` (a single entry, because a
    single source tensor is passed).
    """
    predictions = model(inputs)
    objective = tf.reduce_sum(predictions)
    return tf.gradients(objective, inputs)

@tf.function
def get_hessian(inputs):
    """Hessian of the summed model output with respect to ``inputs``.

    For the single Dense(1) model defined above, this call raises
    ``ValueError: None values not supported.``
    """
    predictions = model(inputs)
    objective = tf.reduce_sum(predictions)
    return tf.hessians(objective, inputs)

batch_size = 3
test_input = tf.random.uniform(shape=(batch_size, 10))

# The forward pass and first-order gradients both succeed...
out = model(test_input)  # works fine
grads = get_grads(test_input)  # works fine

# ...but the second-order call fails for this linear model.
hessian = get_hessian(test_input)  # raises ValueError: None values not supported.

While the forward pass and the get_grads function work fine, the get_hessian function raises `ValueError: None values not supported.`

By contrast, in this simpler example

@tf.function
def get_hessian_(inputs):
    """Hessian of ``sum(inputs**2)`` with respect to ``inputs``.

    The objective is quadratic in each element, so the result is ``2 * I``
    (see the output below).
    """
    objective = tf.reduce_sum(inputs ** 2)
    hessians = tf.hessians(objective, inputs)
    return hessians


get_hessian_(tf.random.uniform(shape=(3,)))[0]
# <tf.Tensor: shape=(3, 3), dtype=float32, numpy=
# array([[2., 0., 0.],
#        [0., 2., 0.],
#        [0., 0., 2.]], dtype=float32)>

tf.hessians yields the expected result without error.

Upvotes: 0

Views: 1297

Answers (1)

Laplace Ricky
Laplace Ricky

Reputation: 1687

In your code example,

you are trying to get the Hessian of f(x) (the model outputs) w.r.t. x (the inputs), and f is linear (the model is linear).

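The Hessian of f(x) w.r.t. x should actually be a zero tensor, but tf.hessians can't handle that properly, resulting in the error. The gradient of a linear function does not depend on x, so differentiating it a second time yields None rather than zeros, and tf.hessians fails on those None values. Adding an additional layer with a non-linear activation will eliminate the error.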

Code examples:

Using tf.hessians to get the Hessian:

# `Dense` is never imported in this snippet, so reference it through tf.keras.
# The sigmoid layer makes the model non-linear, so its second derivative
# w.r.t. the input is connected to the input and tf.hessians can compute it.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='sigmoid'),  # remove this line and you will get an error
    tf.keras.layers.Dense(1),
])
@tf.function
def get_hessian(inputs):
    """Return ``tf.hessians`` of the summed model output w.r.t. ``inputs``.

    The result is a list with one Hessian per source tensor; for the
    ``(3, 10)`` input used below, its single entry has shape ``(3, 10, 3, 10)``.
    """
    predictions = model(inputs)
    objective = tf.reduce_sum(predictions)
    return tf.hessians(objective, inputs)

batch_size = 3
tf.random.set_seed(123)
# Use batch_size rather than repeating the literal 3 so the two cannot drift apart.
test_input = tf.random.uniform((batch_size, 10), minval=1.5, maxval=2.5)
hessian = get_hessian(test_input)
print(type(hessian))  # tf.hessians returns a Python list
print(len(hessian))  # one entry per source tensor
print(hessian[0].shape)  # (batch, features, batch, features)
print(hessian[0][0, 0, 0, 0])
print(hessian[0][0, 0, 0, 1])
# Output reported by the answer author; exact float values may differ slightly
# between runs/TF versions.
'''
<class 'list'>
1
(3, 10, 3, 10)
tf.Tensor(0.0028595054, shape=(), dtype=float32)
tf.Tensor(0.0009458237, shape=(), dtype=float32)
'''

Using tf.GradientTape() to get the Hessian:

# `Dense` is never imported in this snippet, so reference it through tf.keras.
# Without the non-linear sigmoid layer, the gradient does not depend on the
# input and the outer tape's jacobian comes back as None.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(10, activation='sigmoid'),  # remove this line and get_hessian returns None
    tf.keras.layers.Dense(1),
])
@tf.function
def get_hessian(inputs):
    """Hessian of the summed model output using two nested GradientTapes.

    The inner tape records the forward pass to get the gradient w.r.t.
    ``inputs``; the outer tape records that gradient computation, so its
    ``jacobian`` gives the second derivatives (shape ``inputs.shape +
    inputs.shape``). Returns ``None`` when the gradient is not connected to
    ``inputs``.
    """
    with tf.GradientTape() as outer_tape:
        outer_tape.watch(inputs)
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(inputs)
            loss = tf.reduce_sum(model(inputs))
        first_order = inner_tape.gradient(loss, inputs)
    return outer_tape.jacobian(first_order, inputs)

batch_size = 3
tf.random.set_seed(123)
# Use batch_size rather than repeating the literal 3 so the two cannot drift apart.
test_input = tf.random.uniform((batch_size, 10), minval=1.5, maxval=2.5)
hessian = get_hessian(test_input)
print(type(hessian))
# get_hessian returns None when the model is purely linear, so guard each access.
print(hessian.shape if hessian is not None else None)
print(hessian[0, 0, 0, 0] if hessian is not None else None)
print(hessian[0, 0, 0, 1] if hessian is not None else None)
# Output reported by the answer author; exact float values may differ slightly
# between runs/TF versions.
'''
<class 'tensorflow.python.framework.ops.EagerTensor'>
(3, 10, 3, 10)
tf.Tensor(0.0028595058, shape=(), dtype=float32)
tf.Tensor(0.0009458238, shape=(), dtype=float32)
'''

If you want a zero tensor instead, pass unconnected_gradients=tf.UnconnectedGradients.ZERO to both gradient() and jacobian():

# `Dense` is never imported in this snippet, so reference it through tf.keras.
# A single Dense(1) layer with no activation: the model is linear, so its
# Hessian w.r.t. the input is identically zero.
model = tf.keras.Sequential([
    tf.keras.layers.Dense(1),
])
@tf.function
def get_hessian(inputs):
    """Hessian via nested GradientTapes, with disconnected gradients as zeros.

    Passing ``unconnected_gradients=tf.UnconnectedGradients.ZERO`` to both
    ``gradient`` and ``jacobian`` makes them return zero tensors instead of
    ``None`` when the target does not depend on ``inputs``. A purely linear
    model therefore yields an all-zero Hessian rather than ``None``.
    """
    zero_fill = tf.UnconnectedGradients.ZERO
    with tf.GradientTape() as outer_tape:
        outer_tape.watch(inputs)
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(inputs)
            loss = tf.reduce_sum(model(inputs))
        first_order = inner_tape.gradient(
            loss, inputs, unconnected_gradients=zero_fill)
    return outer_tape.jacobian(
        first_order, inputs, unconnected_gradients=zero_fill)

batch_size = 3
tf.random.set_seed(123)
# Use batch_size rather than repeating the literal 3 so the two cannot drift apart.
test_input = tf.random.uniform((batch_size, 10), minval=1.5, maxval=2.5)
hessian = get_hessian(test_input)
print(type(hessian))
print(hessian.shape)
print(tf.math.count_nonzero(hessian))  # 0: the Hessian of a linear model is all zeros
'''
<class 'tensorflow.python.framework.ops.EagerTensor'>
(3, 10, 3, 10)
tf.Tensor(0, shape=(), dtype=int64)
'''

Upvotes: 0

Related Questions