I am trying to define a function whose JVP is defined only for selected output(s). Below is a simple example:
from jax import custom_jvp, jacobian
@custom_jvp
def func(x, y):
    return x+y, x*y
@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents
    primals_out = func(x, y)
    tangents_out = (t0+t1, None)
    return primals_out, tangents_out
if __name__ == "__main__":
    x = 1.
    y = 2.
    print(jacobian(func)(x, y))
The error says:
TypeError: Custom JVP rule func_jvp for function func must produce primal and tangent outputs with equal container (pytree) structures, but got PyTreeDef((*, *)) and PyTreeDef((*, None)) respectively.
Is there a workaround for this case?
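The custom_jvp rule must return tangents for both outputs. The tangent output must have the same pytree structure as the primal output, and None is treated as an empty pytree node rather than a leaf, so returning it produces the structure mismatch reported in the TypeError. The best way to address this is to return the appropriate tangent for both outputs: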
@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents
    primals_out = func(x, y)
    tangents_out = (t0+t1, t0 * y + t1 * x)
    return primals_out, tangents_out
This is the correct output tangent, given by the product rule: d(x·y) = dx·y + dy·x.
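To check the corrected rule, here is a self-contained sketch (assuming a recent JAX release) that redefines func with the fixed JVP and compares its Jacobian with respect to both arguments against JAX's own differentiation of an undecorated copy. The copy is named plain here purely for illustration.

```python
import jax
import jax.numpy as jnp
from jax import custom_jvp

@custom_jvp
def func(x, y):
    return x + y, x * y

@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents
    primals_out = func(x, y)
    # d(x + y) = dx + dy;  d(x * y) = dx * y + dy * x
    tangents_out = (t0 + t1, t0 * y + t1 * x)
    return primals_out, tangents_out

def plain(x, y):
    # Hypothetical reference copy: same math, no custom rule,
    # so JAX differentiates it with its built-in rules
    return x + y, x * y

x, y = 1.0, 2.0
# argnums=(0, 1) differentiates with respect to both inputs;
# the result is indexed as jac[output][argument]
custom_jac = jax.jacobian(func, argnums=(0, 1))(x, y)
plain_jac = jax.jacobian(plain, argnums=(0, 1))(x, y)
print(custom_jac)
```

jax.jacobian is reverse-mode (an alias of jax.jacrev), so JAX obtains the reverse rule by transposing the custom JVP; that works here because both tangent expressions are linear in t0 and t1.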
If your goal is for some tangent outputs to be computed automatically and others manually, you can do this by calling the jax.jvp transformation within your JVP rule. For example:
import jax
@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents
    # Compute all tangents automatically from the undecorated function
    primals_out, tangents_out = jax.jvp(func.fun, primals, tangents)
    # Replace selected tangents with manual values
    tangents_out = (t0 + t1, tangents_out[1])
    return primals_out, tangents_out
This uses func.fun, which is where a custom_jvp object stores the original, non-custom-JVP function definition. Calling jax.jvp on func itself here would invoke this same rule again and recurse indefinitely.
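A complete, runnable version of the mixed approach might look like the following. It is a sketch assuming a recent JAX release; the inputs (1.0, 2.0) and the tangent direction (1.0, 0.0) are arbitrary choices for illustration.

```python
import jax
from jax import custom_jvp

@custom_jvp
def func(x, y):
    return x + y, x * y

@func.defjvp
def func_jvp(primals, tangents):
    x, y = primals
    t0, t1 = tangents
    # func.fun is the undecorated Python function, so jax.jvp
    # differentiates it with JAX's built-in rules instead of
    # re-entering this custom rule
    primals_out, auto_tangents = jax.jvp(func.fun, primals, tangents)
    # Override the tangent of output 0, keep the automatic one for output 1
    tangents_out = (t0 + t1, auto_tangents[1])
    return primals_out, tangents_out

# func.fun is plain Python, so it returns ordinary floats here
print(func.fun(1.0, 2.0))

# Push the tangent direction (1.0, 0.0) through func, i.e. d/dx
primals_out, tangents_out = jax.jvp(func, (1.0, 2.0), (1.0, 0.0))
print(primals_out, tangents_out)
```

With this direction the expected tangents are d(x + y)/dx = 1 and d(x·y)/dx = y = 2.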