Reputation: 14521
In PyTorch, you commonly have to zero out the gradients before doing backpropagation. Is this the case in Flux? If so, what is the programmatic way of doing this?
Upvotes: 0
Views: 337
Reputation: 2580
No, there is no need.
Flux used to use Tracker, a differentiation system in which each tracked array may hold a gradient. I think this is a similar design to PyTorch's. Back-propagating twice accumulates into that stored gradient, which is the problem zeroing is intended to avoid (although the defaults try to protect you):
julia> using Tracker
julia> x_tr = Tracker.param([1 2 3])
Tracked 1×3 Matrix{Float64}:
1.0 2.0 3.0
julia> y_tr = sum(abs2, x_tr)
14.0 (tracked)
julia> Tracker.back!(y_tr, 1; once=false)
julia> x_tr.grad
1×3 Matrix{Float64}:
2.0 4.0 6.0
julia> Tracker.back!(y_tr, 1; once=false) # by default (i.e. with once=true) this would be an error
julia> x_tr.grad
1×3 Matrix{Float64}:
4.0 8.0 12.0
Now it uses Zygote, which does not use tracked array types. Instead, the evaluation to be traced must happen within the call to Zygote.gradient, which can then see and manipulate the source code to write new code for the gradient. Repeated calls generate the same gradients each time; there is no stored state that needs cleaning up.
julia> using Zygote
julia> x = [1 2 3] # an ordinary Array
1×3 Matrix{Int64}:
1 2 3
julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)
julia> Zygote.gradient(x -> sum(abs2, x), x)
([2 4 6],)
julia> y, bk = Zygote.pullback(x -> sum(abs2, x), x);
julia> bk(1.0)
([2.0 4.0 6.0],)
julia> bk(1.0)
([2.0 4.0 6.0],)
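If you actually want the accumulation that un-zeroed gradients give you in PyTorch, with Zygote you add the returned gradients yourself. A quick sketch, continuing with the same x as above (variable names g1, g2 are just for illustration):
julia> g1 = Zygote.gradient(x -> sum(abs2, x), x)[1];
julia> g2 = Zygote.gradient(x -> sum(abs2, x), x)[1];
julia> g1 .+ g2  # accumulation, when you want it, is explicit addition
1×3 Matrix{Int64}:
4 8 12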
Tracker can also be used this way, rather than handling param and back! yourself:
julia> Tracker.gradient(x -> sum(abs2, x), [1, 2, 3])
([2.0, 4.0, 6.0] (tracked),)
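For Flux itself, this means a training loop just calls gradient and update! on each iteration, with no zero-grad step anywhere. A minimal sketch, assuming a recent Flux with the explicit-style setup/update! API; the model, data and learning rate here are made up for illustration:
using Flux

model = Dense(3 => 1)                        # toy model, stands in for yours
x, y = rand(Float32, 3, 10), rand(Float32, 1, 10)
opt_state = Flux.setup(Descent(0.1), model)  # optimiser state lives here, not in the arrays

for epoch in 1:5
    # gradients are recomputed from scratch on every call; nothing to zero out
    grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
end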
Upvotes: 1