Reputation: 79
I have a PyTree params
(in my case a nested dictionary) containing my parameters of a neural network. My goal is to compute the diagonal entries of the Hessian of a loss function with respect to the parameters and store it in a PyTree of the same structure as the parameters.
When I call jax.hessian(loss_fn)(params, data), I get (as expected) an even more deeply nested dictionary containing the full Hessian.
How can I transform this dictionary to get the desired PyTree with diagonal entries?
To be more concrete: let's say I have only one layer in my network and params is given by
params:
    'linear':
        'w': DeviceArray of shape (5, 1)
        'b': DeviceArray of shape (1,)
The returned Hessian has the keys and shapes given by
hessian:
    'linear':
        'b':
            'linear':
                'b': (1, 1),
                'w': (1, 5, 1)
        'w':
            'linear':
                'b': (5, 1, 1),
                'w': (5, 1, 5, 1)
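A minimal sketch that reproduces these shapes (the loss function and data below are made up, purely to make the example self-contained):
import jax
import jax.numpy as jnp

# Hypothetical one-layer setup matching the structure above.
params = {'linear': {'w': jnp.ones((5, 1)), 'b': jnp.zeros((1,))}}
data = (jnp.ones((10, 5)), jnp.zeros((10, 1)))  # (inputs, targets), made up

def loss_fn(params, data):
    x, y = data
    pred = x @ params['linear']['w'] + params['linear']['b']
    return jnp.mean((pred - y) ** 2)

hessian = jax.hessian(loss_fn)(params, data)
print(jax.tree_map(jnp.shape, hessian))
# {'linear': {'b': {'linear': {'b': (1, 1), 'w': (1, 5, 1)}},
#             'w': {'linear': {'b': (5, 1, 1), 'w': (5, 1, 5, 1)}}}}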
As far as I understand it, I need the entries
jnp.diag(hessian['linear']['b']['linear']['b'])
as the diagonal Hessian for the bias and
jnp.diag(jnp.squeeze(hessian['linear']['w']['linear']['w']))
as the diagonal Hessian for the weights. (However, the squeeze may only work for one-dimensional outputs...)
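In other words, for this one-layer example the extraction amounts to reshaping each diagonal block to a square matrix and taking its diagonal. A sketch, assuming the params and hessian from above:
# Each diagonal block hessian[k][p][k][p] has shape s + s (s = parameter shape),
# so it can be reshaped to a square (n, n) matrix whose diagonal holds the
# per-parameter second derivatives.
b = params['linear']['b']
w = params['linear']['w']

diag_b = jnp.diagonal(
    hessian['linear']['b']['linear']['b'].reshape(b.size, b.size)
).reshape(b.shape)

diag_w = jnp.diagonal(
    hessian['linear']['w']['linear']['w'].reshape(w.size, w.size)
).reshape(w.shape)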
How can I automate this transformation so that it works for more complex models with multiple layers?
I know that this does not scale to huge networks; I only need it for testing optimizers.
Upvotes: 2
Views: 618
Reputation: 35
I ran into the exact same problem. Unfortunately, working with PyTrees in JAX can be awkward. I was also looking at a way to construct the diagonal Hessian entry-for-entry, since that could yield a practical method (see the sketch further down).
I now have the following:
from typing import Any, Optional, Sequence

import jax
import jax.numpy as jnp


def ravelled_diagonal_indices(dims: Sequence[int]) -> jnp.ndarray:
    # Get the indices of the diagonal elements of a flattened square matrix.
    return (dims[0] + 1) * jnp.arange(dims[0])


# Alias to reduce clutter.
_diag_idx = ravelled_diagonal_indices


def tree_matrix_diagonal(tree: Any, reference: Optional[Any] = None) -> Any:
    """Utility function for extracting the diagonal of a Pytree of jax.numpy.array objects.

    The Pytree is assumed to be square in its children and in its array objects.

    Parameters
    ----------
    tree : Any
        Pytree of jax.numpy.array objects for which the number of Pytree leaves and
        the sizes of each constituent array are square.
    reference : Any, default = None
        The intended structure for the diagonal of `tree`. For example, this can be
        the Pytree with which `tree` could have been created through e.g. an outer
        product or the Hessian of a function.

    Returns
    -------
    diag : Any
        Pytree containing the flattened diagonals of `tree` if no reference was provided.
        Otherwise, the diagonal elements are shaped according to the structure of `reference`.
    """
    flat = jax.tree_leaves(tree)
    # The leaves form an h x h grid of blocks; keep only the diagonal blocks.
    h = jnp.sqrt(len(flat)).astype(int)
    _idx = _diag_idx((h,))
    block_diag = [flat[i] for i in _idx]
    # Within each (square) diagonal block, keep only the diagonal entries.
    flat_diagonal = lambda w: w.ravel()[_diag_idx((jnp.sqrt(w.size).astype(int),))]
    diag = jax.tree_map(flat_diagonal, block_diag)
    if reference is not None:
        # Reshape the diagonal Pytree to the reference Pytree structure and shapes.
        diag_tree = jax.tree_unflatten(jax.tree_structure(reference), diag)
        diag = jax.tree_multimap(lambda a, b: a.reshape(jnp.shape(b)), diag_tree, reference)
    return diag
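With loss_fn, params and data as in the question (those names are assumed here, not defined in this snippet), the intended usage is:
hessian = jax.hessian(loss_fn)(params, data)
diag_hessian = tree_matrix_diagonal(hessian, reference=params)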
When I try this out on the Hessian of a very simple MLP:
params
>> {'dense/~/affine': {'weights': DeviceArray([[ 1. , 1. ],
[ 0.546326 , -0.77997607]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[ 1. ],
[-0.5155028],
[ 0.9487318]], dtype=float32)}}
hessian
>> {'dense/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[-0.02324889, 0.04278728],
[ 0.00814307, -0.01498652]],
[[ 0.04278728, -0.07874574],
[-0.01498652, 0.0275812 ]]],
[[[ 0.00814307, -0.01498652],
[-0.00285216, 0.00524912]],
[[-0.01498652, 0.0275812 ],
[ 0.00524912, -0.00966049]]]]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[[[[ 0.04509945],
[ 0.15897979],
[ 0.05742025]],
[[-0.08300105],
[-0.06711845],
[ 0.01683405]]],
[[[-0.01579637],
[-0.05568369],
[-0.02011181]],
[[ 0.02907166],
[ 0.02350867],
[-0.00589623]]]]], dtype=float32)}}},
'dense_1/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[ 0.04509945, -0.08300105],
[-0.01579637, 0.02907165]]],
[[[ 0.15897979, -0.06711845],
[-0.0556837 , 0.02350867]]],
[[[ 0.05742024, 0.01683406],
[-0.02011181, -0.00589624]]]]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[[[[-0.08748633],
[-0.07074545],
[-0.11138687]]],
[[[-0.07074545],
[-0.05720801],
[-0.09007253]]],
[[[-0.11138687],
[-0.09007251],
[-0.14181684]]]]], dtype=float32)}}}}
Then, the function returns:
tree_matrix_diagonal(hessian, reference=params)
>> {'dense/~/affine': {'weights': DeviceArray([[-0.02324889, -0.07874574],
[-0.00285216, -0.00966049]], dtype=float32)},
'dense_1/~/affine': {'weights': DeviceArray([[-0.08748633],
[-0.05720801],
[-0.14181684]], dtype=float32)}}
Upon visual inspection, you can see that the returned elements are indeed the diagonal elements of hessian, cast to the canonical structure of params.
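For completeness, here is a hedged sketch of the entry-for-entry idea mentioned at the top, using Hessian-vector products against one-hot vectors (one HVP per parameter, so again only for small models); loss_fn, params and data are assumed to be defined as in the question:
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def hessian_diag_via_hvp(loss_fn, params, data):
    flat_params, unravel = ravel_pytree(params)

    def flat_loss(p):
        return loss_fn(unravel(p), data)

    def hvp(v):
        # Hessian-vector product via forward-over-reverse differentiation.
        return jax.jvp(jax.grad(flat_loss), (flat_params,), (v,))[1]

    # diag_i = e_i^T H e_i, with e_i the i-th standard basis vector.
    basis = jnp.eye(flat_params.size, dtype=flat_params.dtype)
    diag = jax.vmap(lambda e: e @ hvp(e))(basis)
    return unravel(diag)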
Funnily enough, for the Gauss-Newton approximation to the Hessian the procedure is much simpler. Simply take the element-wise square of the Jacobians :).
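A hedged sketch of that remark, assuming a hypothetical per-example model function model_fn(params, x) and a sum-of-squares loss, so that the Gauss-Newton matrix is J^T J with J the Jacobian of the residuals; its diagonal is the squared Jacobian summed over the residual axis:
def ggn_diagonal(model_fn, params, x, y):
    def residuals(p):
        return (model_fn(p, x) - y).ravel()

    # Same structure as params; each leaf gains a leading residual axis.
    jacs = jax.jacobian(residuals)(params)
    # diag(J^T J): square element-wise, then sum over the residual axis.
    return jax.tree_map(lambda j: jnp.sum(j ** 2, axis=0), jacs)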
Upvotes: 1