Reputation: 709
I am trying to build a deep learning model in julia. I have two models m1 and m2 which are neural networks. Here is my code:
using Flux

function even_mask(x)
    s1, s2 = size(x)
    weight_mask = zeros(s1, s2)
    weight_mask[2:2:s1, :] = ones(Int(s1/2), s2)
    return weight_mask
end

function odd_mask(x)
    s1, s2 = size(x)
    weight_mask = zeros(s1, s2)
    weight_mask[1:2:s1, :] = ones(Int(s1/2), s2)
    return weight_mask
end

function even_duplicate(x)
    s1, s2 = size(x)
    x_ = zeros(s1, s2)
    x_[1:2:s1, :] = x[1:2:s1, :]
    x_[2:2:s1, :] = x[1:2:s1, :]
    return x_
end

function odd_duplicate(x)
    s1, s2 = size(x)
    x_ = zeros(s1, s2)
    x_[1:2:s1, :] = x[2:2:s1, :]
    x_[2:2:s1, :] = x[2:2:s1, :]
    return x_
end

function Even(m)
    x -> x .+ even_mask(x) .* m(even_duplicate(x))
end

function InvEven(m)
    x -> x .- even_mask(x) .* m(even_duplicate(x))
end

function Odd(m)
    x -> x .+ odd_mask(x) .* m(odd_duplicate(x))
end

function InvOdd(m)
    x -> x .- odd_mask(x) .* m(odd_duplicate(x))
end

m1 = Chain(Dense(4,6,relu), Dense(6,5,relu), Dense(5,4))
m2 = Chain(Dense(4,7,relu), Dense(7,4))

forward = Chain(Even(m1), Odd(m2))
inverse = Chain(InvOdd(m2), InvEven(m1))

function loss(x)
    z = forward(x)
    return 0.5*sum(z.*z)
end

opt = Flux.ADAM()
x = rand(4,100)

for i=1:100
    Flux.train!(loss, Flux.params(forward), x, opt)
    println(loss(x))
end
The forward model is a combination of m1 and m2. I need to optimize m1 and m2 so that both the forward and inverse models are trained. But it seems that params(forward) is empty. How can I train my model?
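Inspecting the parameters confirms this (output from a Flux v0.10-era REPL; exact printing may differ):

Flux.params(forward)   # ==> Params([]) — nothing for the optimiser to update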
Upvotes: 4
Views: 744
Reputation: 161
I don't think plain functions can be used as layers in Flux. You need to use the @functor macro to add the extra functionality to collect parameters: https://fluxml.ai/Flux.jl/stable/models/basics/#Layer-helpers-1

In your case, rewriting Even, InvEven, Odd and InvOdd like this should help:
struct Even
    model
end

(e::Even)(x) = x .+ even_mask(x) .* e.model(even_duplicate(x))

Flux.@functor Even
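The remaining three wrappers follow the same pattern (a sketch derived from your original closures; only Even is shown above):

struct InvEven
    model
end
(e::InvEven)(x) = x .- even_mask(x) .* e.model(even_duplicate(x))
Flux.@functor InvEven

struct Odd
    model
end
(o::Odd)(x) = x .+ odd_mask(x) .* o.model(odd_duplicate(x))
Flux.@functor Odd

struct InvOdd
    model
end
(o::InvOdd)(x) = x .- odd_mask(x) .* o.model(odd_duplicate(x))
Flux.@functor InvOdd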
After adding this definition, Flux.params(Even(m1)) should return a non-empty list.
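For example, with m1 as defined in the question (a quick sanity check; the count assumes each Dense layer contributes a weight matrix and a bias vector):

ps = Flux.params(Even(m1))
length(ps)   # 6: a weight and a bias for each of the three Dense layers in m1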
EDIT
An even simpler way to implement Even and friends is to use the built-in SkipConnection layer:
Even(m) = SkipConnection(Chain(even_duplicate, m),
                         (mx, x) -> x .+ even_mask(x) .* mx)
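SkipConnection(layers, connection) computes connection(layers(x), x), so this is equivalent to the original closure. A quick check (m here is a hypothetical small model, not one of the models from the question):

m = Chain(Dense(4, 6, relu), Dense(6, 4))
x = rand(4, 10)
Even(m)(x) ≈ x .+ even_mask(x) .* m(even_duplicate(x))   # true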
I suspect this is a version difference, but with Julia 1.4.1 and Flux v0.10.4 I get the error BoundsError: attempt to access () at index [1] when running your training loop. I need to replace the data with

x = [(rand(4,100), 0)]

Otherwise the loss is applied to each entry in the array x, since train! splats loss over x.
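Roughly, the loop inside train! looks like this (a simplified sketch, not the actual Flux source):

for d in data
    # each element of data is splatted into loss, so data must be a
    # collection of (x, y) tuples rather than a bare matrix
    gs = Flux.gradient(() -> loss(d...), ps)
    Flux.Optimise.update!(opt, ps, gs)
end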
The next error, mutating arrays is not supported, is due to the implementation of *_mask and *_duplicate. These functions construct an array of zeros and then mutate it by replacing values from the input.
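You can reproduce the error in isolation (a minimal sketch, assuming Zygote as Flux's AD backend):

using Zygote
function f(x)
    y = zeros(2)
    y[1] = x          # in-place mutation breaks Zygote's reverse pass
    return y[1]
end
Zygote.gradient(f, 1.0)   # ERROR: Mutating arrays is not supported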
You can use Zygote.Buffer to implement this code in a way that can be differentiated.
using Flux
using Zygote: Buffer

function even_mask(x)
    s1, s2 = size(x)
    weight_mask = Buffer(x)   # Buffer: a mutable scratch array Zygote can differentiate through
    weight_mask[2:2:s1, :] = ones(Int(s1/2), s2)
    weight_mask[1:2:s1, :] = zeros(Int(s1/2), s2)
    return copy(weight_mask)  # copy converts the Buffer back to a regular array
end

function odd_mask(x)
    s1, s2 = size(x)
    weight_mask = Buffer(x)
    weight_mask[2:2:s1, :] = zeros(Int(s1/2), s2)
    weight_mask[1:2:s1, :] = ones(Int(s1/2), s2)
    return copy(weight_mask)
end

function even_duplicate(x)
    s1, s2 = size(x)
    x_ = Buffer(x)
    x_[1:2:s1, :] = x[1:2:s1, :]
    x_[2:2:s1, :] = x[1:2:s1, :]
    return copy(x_)
end

function odd_duplicate(x)
    s1, s2 = size(x)
    x_ = Buffer(x)
    x_[1:2:s1, :] = x[2:2:s1, :]
    x_[2:2:s1, :] = x[2:2:s1, :]
    return copy(x_)
end

Even(m) = SkipConnection(Chain(even_duplicate, m),
                         (mx, x) -> x .+ even_mask(x) .* mx)
InvEven(m) = SkipConnection(Chain(even_duplicate, m),
                            (mx, x) -> x .- even_mask(x) .* mx)
Odd(m) = SkipConnection(Chain(odd_duplicate, m),
                        (mx, x) -> x .+ odd_mask(x) .* mx)
InvOdd(m) = SkipConnection(Chain(odd_duplicate, m),
                           (mx, x) -> x .- odd_mask(x) .* mx)

m1 = Chain(Dense(4,6,relu), Dense(6,5,relu), Dense(5,4))
m2 = Chain(Dense(4,7,relu), Dense(7,4))

forward = Chain(Even(m1), Odd(m2))
inverse = Chain(InvOdd(m2), InvEven(m1))

function loss(x, y)   # y is a dummy target; train! splats each (x, y) tuple into loss
    z = forward(x)
    return 0.5*sum(z.*z)
end

opt = Flux.ADAM(1e-6)
x = [(rand(4,100), 0)]   # one-element dataset; the 0 fills the unused y slot

function train!()
    for i=1:100
        Flux.train!(loss, Flux.params(forward), x, opt)
        println(loss(x[1]...))
    end
end
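Calling train!() then runs 100 training steps and prints the loss after each one:

train!()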
At this point, you get to the real fun of deep networks. After one training step, the training diverges to NaN with the default learning rate. Reducing the initial learning rate to 1e-6 helps, and the loss then looks like it is decreasing.
Upvotes: 4