Jingyang Wang

Reputation: 323

JAX custom_jvp with 'None' output leads to TypeError

I'm trying to define a function whose JVP is defined only for selected output(s). Below is a simple example:

from jax import custom_jvp, jacobian

@custom_jvp
def func(x, y):

    return x+y, x*y


@func.defjvp
def func_jvp(primals, tangents):

    x, y = primals
    t0, t1 = tangents

    primals_out = func(x, y)
    tangents_out = (t0+t1, None)

    return primals_out, tangents_out


if __name__ == "__main__":

    x = 1.
    y = 2.

    print(jacobian(func)(x, y))

The error says:

TypeError: Custom JVP rule func_jvp for function func must produce primal and tangent outputs with equal container (pytree) structures, but got PyTreeDef((*, *)) and PyTreeDef((*, None)) respectively.

Is there a workaround for this case?

Upvotes: 1

Views: 70

Answers (1)

jakevdp

Reputation: 86513

A custom_jvp rule must return tangent outputs with the same pytree structure as the primal outputs, i.e. one tangent per output. None isn't a valid tangent, which is why you get the TypeError. The best way to address this is to return the appropriate tangent for both outputs:

@func.defjvp
def func_jvp(primals, tangents):

    x, y = primals
    t0, t1 = tangents

    primals_out = func(x, y)
    tangents_out = (t0+t1, t0 * y + t1 * x)

    return primals_out, tangents_out

This is the correct output tangent by the product rule: d(x·y) = dx·y + x·dy.
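The corrected rule can be checked end to end. The sketch below is self-contained (it repeats the question's definitions) and compares the custom tangents, obtained through both jax.jvp and jacobian, against ordinary autodiff of the undecorated function func.fun. The input values are simply the question's x = 1, y = 2, and the seed tangent (1.0, 0.0) is an illustrative choice that selects the derivative with respect to x.

```python
import jax
from jax import custom_jvp, jacobian


@custom_jvp
def func(x, y):
    return x + y, x * y


@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents
    primals_out = func(x, y)
    # One tangent per primal output, so both pytrees have structure (*, *).
    tangents_out = (t0 + t1, t0 * y + t1 * x)
    return primals_out, tangents_out


x, y = 1.0, 2.0

# Forward mode: push the tangent (1, 0), i.e. d/dx, through the custom rule.
_, tangents = jax.jvp(func, (x, y), (1.0, 0.0))

# Reverse mode: jacobian differentiates w.r.t. argnums=0 (x) by default.
jac = jacobian(func)(x, y)

# Reference: plain autodiff of the undecorated function, same seed tangent.
_, reference = jax.jvp(func.fun, (x, y), (1.0, 0.0))

print(tangents, jac, reference)
```

All three should agree: 1.0 for the first output and 2.0 (the value of y) for the second. The reverse-mode call also works because the hand-written tangent is linear in t0 and t1, which JAX needs in order to transpose the rule.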

If your goal is for some tangent outputs to be computed automatically and others manually, you can call the jax.jvp transformation within your JVP rule. For example:

import jax

@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents

    # Compute tangents automatically
    primals_out, tangents_out = jax.jvp(func.fun, primals, tangents)

    # Replace selected tangents with manual values
    tangents_out = (t0 + t1, tangents_out[1])

    return primals_out, tangents_out

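This uses func.fun, the attribute where a custom_jvp object stores the original, undecorated function. Differentiating func.fun rather than func means jax.jvp uses ordinary autodiff instead of re-entering the custom rule itself.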
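Here is a self-contained sketch of the mixed approach, this time differentiating with respect to both arguments to show the full Jacobian. The choice of argnums=(0, 1) and the input values are illustrative only.

```python
import jax
from jax import custom_jvp, jacobian


@custom_jvp
def func(x, y):
    return x + y, x * y


@func.defjvp
def func_jvp(primals, tangents):
    t0, t1 = tangents
    # Differentiate the undecorated function with ordinary autodiff;
    # using func.fun (not func) avoids re-entering this custom rule.
    primals_out, auto_tangents = jax.jvp(func.fun, primals, tangents)
    # Override the first output's tangent by hand, keep the automatic second one.
    tangents_out = (t0 + t1, auto_tangents[1])
    return primals_out, tangents_out


# Jacobian w.r.t. both x and y at x = 1, y = 2.
jac = jacobian(func, argnums=(0, 1))(1.0, 2.0)
print(jac)
```

The result is nested by output first and then by argument, so jac[0] holds the partials of x + y (1 and 1) and jac[1] holds the partials of x * y (y and x, i.e. 2 and 1).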

Upvotes: 1
