Philipp D.

Extract the diagonal elements of the Hessian in a neural network in Jax

I have a PyTree params (in my case a nested dictionary) containing the parameters of my neural network. My goal is to compute the diagonal entries of the Hessian of a loss function with respect to the parameters and store them in a PyTree with the same structure as the parameters.

When I call jax.hessian(loss_fn)(params, data), I get, as expected, an even more deeply nested dictionary containing the full Hessian.

How can I transform this dictionary to get the desired PyTree with diagonal entries?

To be more concrete: let's say I have only 1 layer in my network and params is given by

params:
    'linear':
        'w': DeviceArray() of shape [5 x 1]
        'b': DeviceArray() of shape [1]

The returned Hessian has the keys and shapes given by

hessian:
    'linear': 
        'b': 
            'linear': 
                'b': (1, 1), 
                'w': (1, 5, 1), 
        'w': 
            'linear': 
                'b': (5, 1, 1), 
                'w': (5, 1, 5, 1)
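This structure can be reproduced with a minimal sketch; the loss, the data, and the initial values below are hypothetical stand-ins chosen only to match the parameter shapes above. Each block hessian[k1][p1][k2][p2] has shape params[k1][p1].shape + params[k2][p2].shape.

```python
import jax
import jax.numpy as jnp

# Hypothetical one-layer model with a squared-error loss, used only to
# reproduce the parameter shapes from the question.
def loss_fn(params, data):
    x, y = data
    pred = x @ params['linear']['w'] + params['linear']['b']  # shape (batch, 1)
    return 0.5 * jnp.sum((pred - y) ** 2)

params = {'linear': {'w': jnp.ones((5, 1)), 'b': jnp.zeros((1,))}}
data = (jnp.arange(15.0).reshape(3, 5) / 10.0, jnp.ones((3, 1)))

hessian = jax.hessian(loss_fn)(params, data)

# Print the shape of every (row parameter, column parameter) block.
for p1 in ('b', 'w'):
    for p2 in ('b', 'w'):
        print(p1, p2, hessian['linear'][p1]['linear'][p2].shape)
```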

As far as I understand it, I need the entries

jnp.diag(hessian['linear']['b']['linear']['b'])

as the diagonal hessian for the bias and

jnp.diag(jnp.squeeze(hessian['linear']['w']['linear']['w']))

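as the diagonal Hessian for the weights. (However, the squeeze only works here because the layer has a single output; for a weight of shape [5 x 3], squeezing leaves a 4-D array and jnp.diag no longer applies.)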
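A minimal sketch (not part of the original question) that removes the single-output restriction: a diagonal block for a parameter p always has shape p.shape + p.shape, so it can be reshaped to a (p.size, p.size) matrix before taking the diagonal, and the result reshaped back to p.shape. The helper name block_diagonal is hypothetical.

```python
import math

import jax.numpy as jnp

def block_diagonal(block, param_shape):
    """Diagonal of a Hessian block of shape param_shape + param_shape,
    returned with shape param_shape (hypothetical helper)."""
    n = math.prod(param_shape)
    # Flatten both halves of the block into an (n, n) matrix, take its
    # diagonal, and restore the parameter's own shape.
    return jnp.diag(block.reshape(n, n)).reshape(param_shape)

# Example: a block for a [5 x 3] weight, built from a known 15 x 15 matrix.
M = jnp.arange(225.0).reshape(15, 15)
d = block_diagonal(M.reshape(5, 3, 5, 3), (5, 3))
```

For a [5 x 1] weight this gives the same values as the squeeze version, just with shape (5, 1) instead of (5,).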

How can I automate this transformation in order to work for more complex models with multiple layers?

I know that this does not scale to huge networks; I need it for testing optimizers.
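One way to automate this for any number of layers, sketched here as an assumption rather than the asker's code, is to flatten the parameters with jax.flatten_util.ravel_pytree, take the Hessian with respect to the flat vector, and unflatten its diagonal, which restores the original PyTree structure and shapes. The helper name hessian_diagonal and the toy loss are hypothetical. Like jax.hessian on the PyTree, this builds the full P x P matrix, so it is only meant for small test models.

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def hessian_diagonal(loss_fn, params, *args):
    """Hessian diagonal of loss_fn w.r.t. params, as a PyTree shaped like params."""
    flat_params, unravel = ravel_pytree(params)
    # Express the loss as a function of a single flat parameter vector.
    flat_loss = lambda flat: loss_fn(unravel(flat), *args)
    full_hessian = jax.hessian(flat_loss)(flat_params)  # shape (P, P)
    # unravel maps the length-P diagonal back onto the params structure.
    return unravel(jnp.diag(full_hessian))

# Hypothetical one-layer model with a squared-error loss.
def loss_fn(params, data):
    x, y = data
    pred = x @ params['linear']['w'] + params['linear']['b']
    return 0.5 * jnp.sum((pred - y) ** 2)

params = {'linear': {'w': jnp.ones((5, 1)), 'b': jnp.zeros((1,))}}
x = jnp.arange(15.0).reshape(3, 5) / 10.0
y = jnp.ones((3, 1))
diag = hessian_diagonal(loss_fn, params, (x, y))
```

For this loss the expected values are easy to check by hand: the bias entry equals the batch size (3), and each weight entry is the sum of the squared inputs in the corresponding column of x.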


Answers (1)

Joery

I ran into the exact same problem. Unfortunately, working with Pytrees in JAX can be awkward. I was also looking for a way to construct the diagonal Hessian entry by entry, since that could yield a practical method.

I now have the following:

from typing import Any, Optional, Sequence
import math

import jax
import jax.numpy as jnp


def ravelled_diagonal_indices(dims: Sequence[int]) -> jnp.ndarray:
    # Get the indices of the diagonal elements of a flattened (row-major) square matrix.
    return (dims[0] + 1) * jnp.arange(dims[0])


# Alias to reduce clutter.
_diag_idx = ravelled_diagonal_indices


def tree_matrix_diagonal(tree: Any, reference: Optional[Any] = None) -> Any:
    """Utility function for extracting the diagonal of a Pytree of jax.numpy.array objects.

    The Pytree is assumed to be square in its children and in its array objects.

    Parameters
    ----------
    tree : Any
        Pytree of jax.numpy.array objects for which the number of Pytree leaves and
        the size of each constituent array are perfect squares.
    reference : Any, default = None
        The intended structure for the diagonal of `tree`. For example, this can be
        the Pytree from which `tree` could have been created through, e.g., an outer
        product or the Hessian of a function.

    Returns
    -------
    diag : Any
        List (a Pytree) containing the flattened diagonals of `tree` if no reference
        was provided. Otherwise, the diagonal elements are shaped according to the
        structure of `reference`.

    """
    flat = jax.tree_util.tree_leaves(tree)

    # With n parameter arrays the Hessian has n * n leaves, ordered row-major over
    # the outer and inner parameter structure, so the diagonal blocks sit at (n + 1) * i.
    h = math.isqrt(len(flat))
    block_diag = [flat[int(i)] for i in _diag_idx((h,))]

    # A diagonal block has shape p.shape + p.shape, i.e. p.size ** 2 elements.
    flat_diagonal = lambda w: w.ravel()[_diag_idx((math.isqrt(w.size),))]
    diag = jax.tree_util.tree_map(flat_diagonal, block_diag)

    if reference is not None:
        # Reshape the diagonal Pytree to the reference Pytree's structure and shapes.
        diag_tree = jax.tree_util.tree_unflatten(jax.tree_util.tree_structure(reference), diag)
        diag = jax.tree_util.tree_map(lambda a, b: a.reshape(jnp.shape(b)), diag_tree, reference)

    return diag
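tree_matrix_diagonal relies on two ordering assumptions: jax.tree_util.tree_leaves returns the n * n Hessian blocks in row-major order (so the diagonal blocks are at positions (n + 1) * i), and each diagonal block ravels into a square matrix whose diagonal sits at (m + 1) * i. A small check of both, using a hypothetical toy loss whose second derivatives are easy to compute by hand:

```python
import jax
import jax.numpy as jnp

# Hypothetical toy loss with two parameter arrays of different sizes.
def loss_fn(params):
    a, b = params['a'], params['b']
    return jnp.sum(a ** 3) + (jnp.sum(a) * jnp.sum(b)) ** 2

params = {'a': jnp.array([1.0, 2.0]), 'b': jnp.array([0.5, -1.0, 3.0])}
hessian = jax.hessian(loss_fn)(params)

leaves = jax.tree_util.tree_leaves(hessian)
param_leaves = jax.tree_util.tree_leaves(params)
n = len(param_leaves)
print([leaf.shape for leaf in leaves])  # [(2, 2), (2, 3), (3, 2), (3, 3)]

extracted, expected = [], []
for i, p in enumerate(param_leaves):
    block = leaves[(n + 1) * i]  # diagonal block for parameter i
    m = p.size
    # Diagonal via the ravelled-index trick used in tree_matrix_diagonal ...
    extracted.append(block.ravel()[(m + 1) * jnp.arange(m)])
    # ... and via an explicit reshape to (m, m) followed by jnp.diag.
    expected.append(jnp.diag(block.reshape(m, m)))
```

By hand, the second derivatives are 6 * a_i + 2 * sum(b) ** 2 for a and 2 * sum(a) ** 2 for b, i.e. [18.5, 24.5] and [18, 18, 18].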

When I try this out on the Hessian of a very simple MLP:

params
>> {'dense/~/affine': {'weights': DeviceArray([[ 1.        ,  1.        ],
               [ 0.546326  , -0.77997607]], dtype=float32)},
 'dense_1/~/affine': {'weights': DeviceArray([[ 1.       ],
               [-0.5155028],
               [ 0.9487318]], dtype=float32)}}


hessian
>> {'dense/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[-0.02324889,  0.04278728],
                    [ 0.00814307, -0.01498652]],
    
                   [[ 0.04278728, -0.07874574],
                    [-0.01498652,  0.0275812 ]]],
    
    
                  [[[ 0.00814307, -0.01498652],
                    [-0.00285216,  0.00524912]],
    
                   [[-0.01498652,  0.0275812 ],
                    [ 0.00524912, -0.00966049]]]]], dtype=float32)},
   'dense_1/~/affine': {'weights': DeviceArray([[[[[ 0.04509945],
                    [ 0.15897979],
                    [ 0.05742025]],
    
                   [[-0.08300105],
                    [-0.06711845],
                    [ 0.01683405]]],
    
    
                  [[[-0.01579637],
                    [-0.05568369],
                    [-0.02011181]],
    
                   [[ 0.02907166],
                    [ 0.02350867],
                    [-0.00589623]]]]], dtype=float32)}}},
 'dense_1/~/affine': {'weights': {'dense/~/affine': {'weights': DeviceArray([[[[[ 0.04509945, -0.08300105],
                    [-0.01579637,  0.02907165]]],
    
    
                  [[[ 0.15897979, -0.06711845],
                    [-0.0556837 ,  0.02350867]]],
    
    
                  [[[ 0.05742024,  0.01683406],
                    [-0.02011181, -0.00589624]]]]], dtype=float32)},
   'dense_1/~/affine': {'weights': DeviceArray([[[[[-0.08748633],
                    [-0.07074545],
                    [-0.11138687]]],
    
    
                  [[[-0.07074545],
                    [-0.05720801],
                    [-0.09007253]]],
    
    
                  [[[-0.11138687],
                    [-0.09007251],
                    [-0.14181684]]]]], dtype=float32)}}}}

Then, the function returns:

tree_matrix_diagonal(hessian, reference=params)
>> {'dense/~/affine': {'weights': DeviceArray([[-0.02324889, -0.07874574],
               [-0.00285216, -0.00966049]], dtype=float32)},
 'dense_1/~/affine': {'weights': DeviceArray([[-0.08748633],
               [-0.05720801],
               [-0.14181684]], dtype=float32)}}

Upon visual inspection, you can see that the returned elements are indeed the diagonal elements of hessian cast to the canonical structure of params.

Funnily enough, for the Gauss-Newton approximation J^T J of the Hessian (e.g., for a squared-error loss) the procedure is much simpler: take the element-wise square of the Jacobian of the model outputs and sum it over the output dimensions :).
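A minimal sketch of that Gauss-Newton diagonal, assuming a squared-error loss so that the approximation is G = J^T J; the model, the helper name gauss_newton_diagonal, and the data are hypothetical. jax.jacobian returns, for each parameter, an array of shape output_shape + param_shape, so squaring and summing over the output axes gives diag(J^T J) directly in the shape of params.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    # Hypothetical linear model; output shape (batch, 1).
    return x @ params['w'] + params['b']

def gauss_newton_diagonal(params, x):
    """diag(J^T J), shaped like params, where J is the Jacobian of the model outputs.

    For the loss 0.5 * sum((model(params, x) - y) ** 2), J^T J is the Gauss-Newton
    approximation of the Hessian (exact when the model is linear in params).
    """
    out_ndim = model(params, x).ndim
    # Each Jacobian leaf has shape output_shape + param_shape.
    jac = jax.jacobian(model)(params, x)
    # Square element-wise and sum over all output (including batch) axes.
    return jax.tree_util.tree_map(
        lambda j: jnp.sum(j ** 2, axis=tuple(range(out_ndim))), jac)

params = {'w': jnp.ones((5, 1)), 'b': jnp.zeros((1,))}
x = jnp.arange(15.0).reshape(3, 5) / 10.0
gn_diag = gauss_newton_diagonal(params, x)
```

Because this toy model is linear in its parameters, the result coincides with the exact Hessian diagonal of the squared-error loss; for a nonlinear network the two differ.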
