Reputation: 285
I am trying to use quantile in a loss function passed to train! (for some robustness, like least trimmed squares), but quantile mutates the array and Zygote throws the error "Mutating arrays is not supported", coming from sort!. Below is a simple example (the content does not make sense, of course):
using Flux, StatsBase

xdata = randn(2, 100)
ydata = randn(100)
model = Chain(Dense(2, 10), Dense(10, 1))

function trimmedLoss(x, y; trimFrac=0.05f0)  # fraction of the largest residuals to drop
    yhat = vec(model(x))  # flatten the 1×N model output to match the length-N targets
    absRes = abs.(yhat .- y)
    trimVal = quantile(absRes, 1.f0 - trimFrac)  # quantile calls sort! internally
    s = sum(ifelse.(absRes .> trimVal, 0.f0, absRes)) / (length(absRes) * (1.f0 - trimFrac))
    #s = sum(absRes)/length(absRes) # using this and commenting out the two above works (no surprise)
end

println(trimmedLoss(xdata, ydata)) # works ok
Flux.train!(trimmedLoss, params(model), zip([xdata], [ydata]), ADAM()) # throws the Zygote error
println(trimmedLoss(xdata, ydata)) # changed loss?
This is all with Flux 0.10 on Julia 1.2.
Thanks in advance for any hints or workarounds!
Upvotes: 7
Views: 593
Reputation: 1406
Ideally, we'd define a custom adjoint for quantile so that this works out of the box. (Feel free to open an issue to remind us to do this.)
In the meantime there's a quick workaround. It's actually the sorting that causes trouble here, so if you call quantile(xs, p, sorted=true) it'll work. Obviously this requires xs to be sorted to give correct results, so you might need to use quantile(sort(xs), ...).
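Applied to the loss from the question, the workaround might look like the following. This is only a sketch: trimmed_abs_loss is a hypothetical helper name, it takes the predictions directly instead of calling the model, and it only illustrates the quantile(sort(...), p, sorted=true) pattern. Whether it differentiates without the extra sort adjoint depends on your Zygote version.

```julia
using Statistics

# Hypothetical helper (not part of Flux or Zygote): mean of the absolute
# residuals after zeroing out those above the (1 - trimFrac) quantile.
function trimmed_abs_loss(yhat, y; trimFrac = 0.05f0)
    absRes = abs.(vec(yhat) .- y)
    # sort returns a new array, and sorted=true tells quantile to skip its own sort!
    trimVal = quantile(sort(absRes), 1 - trimFrac, sorted = true)
    kept = ifelse.(absRes .> trimVal, zero(eltype(absRes)), absRes)
    return sum(kept) / (length(absRes) * (1 - trimFrac))
end
```

Inside a Flux loss you would call it as trimmed_abs_loss(model(x), y), so the only sorting Zygote has to differentiate through is the non-mutating sort.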
Depending on your Zygote version you might also need an adjoint for sort. That one's pretty easy:
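julia> using Zygote, Statistics

julia> using Zygote: @adjoint

julia> @adjoint function sort(x)
         p = sortperm(x)                   # permutation that sorts x
         x[p], x̄ -> (x̄[invperm(p)],)       # pullback: send each gradient entry back to its original position
       end

julia> gradient(x -> quantile(sort(x), 0.5, sorted=true), [1, 2, 3, 3])
([0.0, 0.5, 0.5, 0.0],)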
We'll make that built-in in the next Zygote release, but for now if you add that to your script it'll get your code working.
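To see why the pullback in the sort adjoint indexes with invperm(p), here is a small check in plain Julia, without Zygote. The input values and the upstream gradient ȳ are made up for illustration.

```julia
x = [3.0, 1.0, 2.0]
p = sortperm(x)           # [2, 3, 1]: sorted position i holds x[p[i]]
y = x[p]                  # [1.0, 2.0, 3.0], the same values sort(x) returns

ȳ = [10.0, 20.0, 30.0]    # hypothetical gradient with respect to the sorted output
x̄ = ȳ[invperm(p)]         # gradient with respect to x, in x's original order

# x[1] = 3.0 ends up in sorted position 3, so it receives ȳ[3], and so on.
println(x̄)                # [30.0, 10.0, 20.0]
```

Sorting is just a permutation of the entries, so its pullback is the inverse permutation applied to the incoming gradient.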
Upvotes: 9