luciano-drozda

Reputation: 71

Julia: Zygote.@adjoint from Enzyme.autodiff

Given the function f! below:

function f!(s::Vector, a::Vector, b::Vector)
  
  s .= a .+ b
  return nothing

end # f!

How can I define an adjoint for Zygote based on

Enzyme.autodiff(f!, Const, Duplicated(s, dz_ds), Duplicated(a, zero(a)), Duplicated(b, zero(b)))?

Zygote.@adjoint f!(s, a, b) = f!(s, a, b), # What should go here?

Upvotes: 3

Views: 414

Answers (1)

luciano-drozda

Reputation: 71

I figured out a way, so I am sharing it here.

For a given function foo, Zygote.pullback(foo, args...) returns both foo(args...) and the backward pass (a closure from which gradients can be computed).
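As a minimal sketch of what that means (foo below is a made-up example function, not something from the question):

```julia
using Zygote

# Hypothetical example function: sum of squares.
foo(x) = sum(abs2, x)

x = [1.0, 2.0, 3.0]

# Forward pass: y is foo(x), back is the backward pass (the pullback).
y, back = Zygote.pullback(foo, x)

# Backward pass: seed it with the cotangent of the scalar output.
# The result is a tuple with one entry per argument of foo.
(dx,) = back(1.0)
```

Here dx should equal 2 .* x, which is also what Zygote.gradient(foo, x) would give.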

My goal is to tell Zygote to use Enzyme for the backward pass.

This can be done by means of Zygote.@adjoint (see more here).
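A small sketch of the pattern, assuming a toy function double that I made up only for illustration: the left side of the comma is the forward result, the right side is a closure that maps the incoming gradient to a tuple with one gradient per argument.

```julia
using Zygote

# Hypothetical function used only for illustration.
double(x) = 2 .* x

# Forward result, then a hand-written backward pass.
# The backward pass must return a tuple: one cotangent per argument of double.
Zygote.@adjoint double(x) = double(x), dy -> (2 .* dy,)

# Zygote now calls the rule above instead of differentiating double itself.
grad = Zygote.gradient(x -> sum(double(x)), [1.0, 2.0])
```

The same shape is what is needed for Enzyme: the closure on the right just has to call Enzyme instead of a hand-written formula.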

For array-valued functions, Enzyme requires a mutating version that returns nothing and writes its result into one of its arguments (see more here).

The function f! in the question post is an Enzyme-compatible version of a sum of two arrays.

Since f! returns nothing, Zygote would simply return nothing when the backward pass is called on an incoming gradient.

A solution is to place f! inside a wrapper (say f) that returns the array s, and to define Zygote.@adjoint for f rather than for f!.

Hence,

function f(a::Vector, b::Vector)

  s = zero(a)
  f!(s, a, b)
  return s

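end

function enzyme_back(dzds, a, b)

  s    = zero(a)
  ds   = copy(dzds)   # Enzyme may overwrite the shadow of s, so keep Zygote's gradient intact
  dzda = zero(dzds)   # shadow of a, receives dz/da
  dzdb = zero(dzds)   # shadow of b, receives dz/db
  # Call form of the Enzyme release used here; newer releases take the
  # mode first: Enzyme.autodiff(Reverse, f!, Const, ...)
  Enzyme.autodiff(
    f!,
    Const,
    Duplicated(s, ds),
    Duplicated(a, dzda),
    Duplicated(b, dzdb)
  )
  return (dzda, dzdb)   # one gradient per argument of f

end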

and

Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)

tell Zygote to use Enzyme in the backward pass.


Finally, you can check that calling Zygote.gradient either on

g1(a::Vector, b::Vector) = sum(abs2, a + b)

or

g2(a::Vector, b::Vector) = sum(abs2, f(a, b))

yields the same results.
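
For convenience, here is a self-contained sketch of that check. It assumes a recent Enzyme release, where the mode (Enzyme.Reverse) is passed as the first argument to autodiff; with an older release matching the call above, drop that argument. The input values and the copy of dzds are my own choices, and Enzyme is imported with qualified names because both packages export a function called gradient.

```julia
using Zygote
import Enzyme

# Mutating, Enzyme-compatible sum of two arrays (from the question).
function f!(s::Vector, a::Vector, b::Vector)
  s .= a .+ b
  return nothing
end

# Wrapper that returns the result, so Zygote has an output to work with.
function f(a::Vector, b::Vector)
  s = zero(a)
  f!(s, a, b)
  return s
end

# Backward pass computed by Enzyme.
function enzyme_back(dzds, a, b)
  s    = zero(a)
  ds   = collect(dzds)   # private, mutable copy of the incoming gradient
  dzda = zero(ds)
  dzdb = zero(ds)
  Enzyme.autodiff(
    Enzyme.Reverse,      # assumption: recent Enzyme API, mode goes first
    f!,
    Enzyme.Const,
    Enzyme.Duplicated(s, ds),
    Enzyme.Duplicated(a, dzda),
    Enzyme.Duplicated(b, dzdb)
  )
  return (dzda, dzdb)
end

Zygote.@adjoint f(a, b) = f(a, b), dzds -> enzyme_back(dzds, a, b)

g1(a::Vector, b::Vector) = sum(abs2, a + b)    # differentiated by Zygote alone
g2(a::Vector, b::Vector) = sum(abs2, f(a, b))  # backward pass of f goes through Enzyme

# Example inputs (arbitrary values).
a = [1.0, 2.0, 3.0]
b = [4.0, 5.0, 6.0]

grads1 = Zygote.gradient(g1, a, b)
grads2 = Zygote.gradient(g2, a, b)

# Compare the gradients with respect to a and b pairwise.
same = all(map(isapprox, grads1, grads2))
```

Both gradients should come out as 2 .* (a .+ b) for each argument, so same is expected to be true.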

Upvotes: 4
