logankilpatrick

Reputation: 14521

Is there a `zero_grad()` function in Flux.jl

In PyTorch, you commonly have to zero out the gradients before doing backpropagation. Is this the case in Flux? If so, what is the programmatic way of doing this?

Upvotes: 0

Views: 337

Answers (1)

mcabbott

Reputation: 2580

tl;dr

No, there is no need.

explanation

Flux used to use Tracker, a differentiation system in which each tracked array may hold a gradient. I think this is a similar design to PyTorch's. Back-propagating twice can lead to the problem which zeroing is intended to avoid (although the defaults try to protect you):

julia> using Tracker

julia> x_tr = Tracker.param([1 2 3])
Tracked 1×3 Matrix{Float64}:
 1.0  2.0  3.0

julia> y_tr = sum(abs2, x_tr)
14.0 (tracked)

julia> Tracker.back!(y_tr, 1; once=false)

julia> x_tr.grad
1×3 Matrix{Float64}:
 2.0  4.0  6.0

julia> Tracker.back!(y_tr, 1; once=false) # by default (i.e. with once=true) this would be an error

julia> x_tr.grad
1×3 Matrix{Float64}:
 4.0  8.0  12.0

Flux now uses Zygote, which does not use tracked array types. Instead, the evaluation to be traced must happen within the call to Zygote.gradient, which can then see and manipulate the source code to write new code for the gradient. Repeated calls generate the same gradients each time; there is no stored state that needs cleaning up.

julia> using Zygote

julia> x = [1 2 3]  # an ordinary Array
1×3 Matrix{Int64}:
 1  2  3

julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)

julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)

julia> y, bk = Zygote.pullback(x -> sum(abs2, x), x);

julia> bk(1.0)
([2.0 4.0 6.0],)

julia> bk(1.0)
([2.0 4.0 6.0],)

Tracker can also be used this way, rather than handling param and back! yourself:

julia> Tracker.gradient(x -> sum(abs2, x), [1, 2, 3])
([2.0, 4.0, 6.0] (tracked),)
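
In practice this means a Flux training loop just calls gradient afresh on every step and applies the update; there is nothing to reset between steps. Here is a minimal sketch using the explicit-gradient style of recent Flux versions (the toy model, data, and optimiser settings are made up for illustration):

using Flux

model = Dense(3 => 1)                      # toy model
opt_state = Flux.setup(Adam(0.01), model)  # optimiser state lives here, not inside the arrays
x, y = rand(Float32, 3, 16), rand(Float32, 1, 16)

for step in 1:100
    # a fresh gradient is computed on every call -- nothing accumulates,
    # so there is no equivalent of optimizer.zero_grad()
    grads = Flux.gradient(m -> sum(abs2, m(x) .- y), model)
    Flux.update!(opt_state, model, grads[1])
end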

Upvotes: 1
