Reputation: 333
I am building and training a neural network model with Flux, and I am wondering if there is a way to take linear combinations of Zygote.Grads objects.
Here is a minimalistic example. This is how it is typically done:
using Flux
m = hcat(2.0); b = hcat(-1.0); # random 1 x 1 matrices
f(x) = m*x .+ b
ps = Flux.params(m, b) # parameters to be adjusted
inputs = [0.3 1.5] # random 1 x 2 matrix
loss(x) = sum( f(x).^2 )
gs = Flux.gradient(() -> loss(inputs), ps) # the typical way
@show gs[m], gs[b] # 5.76, 3.2
But I want to do the same calculation by computing gradients at a deeper level and then assembling them at the end. For example:
input1 = hcat(inputs[1, 1]); input2 = hcat(inputs[1, 2]); # turn each input into a 1 x 1 matrix
grad1 = Flux.gradient(() -> f(input1)[1], ps) # df/dp using input1 (where p is m or b)
grad2 = Flux.gradient(() -> f(input2)[1], ps) # df/dp using input2 (where p is m or b)
predicted1 = f(input1)[1]
predicted2 = f(input2)[1]
myGrad_m = (2 * predicted1 * grad1[m]) + (2 * predicted2 * grad2[m]) # 5.76
myGrad_b = (2 * predicted1 * grad1[b]) + (2 * predicted2 * grad2[b]) # 3.2
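As a quick sanity check (reusing gs from the first snippet), the manually assembled gradients agree with the ones computed the typical way:
@assert myGrad_m ≈ gs[m] # both are 1 x 1 matrices holding 5.76
@assert myGrad_b ≈ gs[b] # both are 1 x 1 matrices holding 3.2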
Above, I used the chain rule and linearity of the derivative to decompose the gradient of the loss() function:
d(loss)/dp = d( sum(f^2) )/dp = sum( d(f^2)/dp ) = sum( 2*f * df/dp )
Then, I calculated df/dp using Zygote.gradient, and then combined the results at the end.
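Concretely, f(0.3) = -0.4 and f(1.5) = 2.0, and since df/dm = x and df/db = 1, this gives
d(loss)/dm = 2*(-0.4)*0.3 + 2*(2.0)*1.5 = 5.76
d(loss)/db = 2*(-0.4)*1 + 2*(2.0)*1 = 3.2
which matches gs[m] and gs[b] from the first snippet.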
But notice that I had to combine m and b separately. This was fine because there were only 2 parameters.
However, if there were 1000 parameters, I would want to do something like this, which is a linear combination of the Zygote.Grads objects:
myGrad = (2 * predicted1 * grad1) + (2 * predicted2 * grad2)
But I get an error saying that the + and * operators are not defined for these types. How can I get this shortcut to work?
Upvotes: 2
Views: 206
Reputation: 925
Just turn each * / + into .* / .+ (i.e. use broadcasting), or you can use map to apply a function to multiple Grads at once. This is described in the Zygote docs. Note that in order for this to work, all the Grads must share the same keys (so they must correspond to the same parameters).
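For example, applied to the variables from the question (a sketch reusing grad1, grad2, predicted1, and predicted2 as defined there), the broadcast version of the shortcut is:
myGrad = (2 * predicted1) .* grad1 .+ (2 * predicted2) .* grad2
@show myGrad[m], myGrad[b] # 5.76, 3.2
The result is again a Grads indexed by the same parameters. The same combination written with map, which receives the corresponding gradient entry from each Grads, would look like:
myGrad = map((g1, g2) -> 2 * predicted1 * g1 + 2 * predicted2 * g2, grad1, grad2)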
Upvotes: 1