Given the function f! below:
function f!(s::Vector, a::Vector, b::Vector)
    s .= a .+ b
    return nothing
end # f!
How can I define an adjoint for Zygote based on the following Enzyme call?
Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds), Duplicated(a, zero(a)), Duplicated(b, zero(b)))
Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What would come here?
I figured out a way, so I am sharing it here.
For a given function foo, Zygote.pullback(foo, args...) returns foo(args...) together with the backward pass (a closure that allows gradients to be computed).
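As a minimal sketch of what Zygote.pullback hands back (the function foo and the input x below are illustrative choices, not taken from the post):

```julia
using Zygote

# Illustrative scalar-valued function (an assumption for this sketch).
foo(x) = sum(abs2, x)

x = [1.0, 2.0, 3.0]

# pullback returns the primal value and a closure for the backward pass.
y, back = Zygote.pullback(foo, x)

# Seeding the closure with 1.0 gives a tuple with one gradient per argument.
grad_x, = back(1.0)

println(y)       # 14.0
println(grad_x)  # [2.0, 4.0, 6.0]
```

The closure back is the piece that a custom adjoint gets to replace.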
My goal is to tell Zygote to use Enzyme for the backward pass.
This can be done by means of Zygote.@adjoint (see more here).
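A rough sketch of the Zygote.@adjoint pattern with a hand-written backward pass, before Enzyme enters the picture. The name add_vec is hypothetical and used only for illustration:

```julia
using Zygote

# Hypothetical out-of-place sum of two vectors.
add_vec(a, b) = a .+ b

# Custom adjoint: the forward value, plus a closure mapping the output
# cotangent dz to one cotangent per input. For a sum, both inputs
# receive dz unchanged.
Zygote.@adjoint add_vec(a, b) = add_vec(a, b), dz -> (dz, dz)

a = [1.0, 2.0]
b = [3.0, 4.0]
da, db = Zygote.gradient((a, b) -> sum(abs2, add_vec(a, b)), a, b)

println(da)  # [8.0, 12.0]
println(db)  # [8.0, 12.0]
```

The closure after the comma is the slot where a call into another AD tool can be placed instead of the hand-written rule.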
For array-valued functions, Enzyme requires a mutating version that returns nothing and writes its result into one of its arguments (see more here).
The function f! in the question post is an Enzyme-compatible version of a sum of two arrays.
Since f! returns nothing, Zygote would simply return nothing when the backward pass is called on an incoming gradient.
A solution is to place f! inside a wrapper (say f) that returns the array s, and to define Zygote.@adjoint for f rather than for f!.
Hence,
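using Enzyme, Zygote

# Out-of-place wrapper around f!: allocates the output and returns it,
# so that Zygote has a value to differentiate through.
function f(a::Vector, b::Vector)
    s = zero(a)
    f!(s, a, b)
    return s
end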
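# Backward pass: given the cotangent dzds of the output s, let Enzyme
# accumulate the cotangents of a and b into dzda and dzdb.
function enzyme_back(dzds, a, b)
    s = zero(a)
    dzda = zero(dzds)  # shadow of a, filled in by Enzyme
    dzdb = zero(dzds)  # shadow of b, filled in by Enzyme
    # Call form as used with the Enzyme version in this answer; more recent
    # Enzyme releases expect the mode first: Enzyme.autodiff(Reverse, f!, Const, ...).
    Enzyme.autodiff(
        f!,
        Const,
        Duplicated(s, dzds),
        Duplicated(a, dzda),
        Duplicated(b, dzdb)
    )
    return (dzda, dzdb)
end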
and
Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)
inform Zygote to use Enzyme in the backward pass.
Finally, you can check that calling Zygote.gradient on either
g1(a::Vector, b::Vector) = sum(abs2, a + b)
or
g2(a::Vector, b::Vector) = sum(abs2, f(a, b))
yields the same result.
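As a sanity check on what both calls should return, the gradient can be worked out by hand. Here z denotes the scalar output (following the dzds naming above) and s = a + b:

```latex
\begin{aligned}
z &= \sum_i s_i^2, \qquad s = a + b \\
\frac{\partial z}{\partial s_i} &= 2 s_i \\
\frac{\partial z}{\partial a_i} &= \sum_j \frac{\partial z}{\partial s_j}\,\frac{\partial s_j}{\partial a_i} = 2\,(a_i + b_i), \qquad
\frac{\partial z}{\partial b_i} = \sum_j \frac{\partial z}{\partial s_j}\,\frac{\partial s_j}{\partial b_i} = 2\,(a_i + b_i)
\end{aligned}
```

So both g1 and g2 should give 2(a + b) for each argument; for this particular f!, the backward pass amounts to passing dzds through to both a and b.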