bitWise

Reputation: 709

How to train a combination of models in Flux?

I am trying to build a deep learning model in Julia. I have two models, m1 and m2, which are neural networks. Here is my code:

using Flux

function even_mask(x)
    s1, s2 = size(x)
    weight_mask = zeros(s1, s2)
    weight_mask[2:2:s1,:] = ones(Int(s1/2), s2)
    return weight_mask
end

function odd_mask(x)
    s1, s2 = size(x)
    weight_mask = zeros(s1, s2)
    weight_mask[1:2:s1,:] = ones(Int(s1/2), s2)
    return weight_mask
end

function even_duplicate(x)
    s1, s2 = size(x)
    x_ = zeros(s1, s2)
    x_[1:2:s1,:] = x[1:2:s1,:]
    x_[2:2:s1,:] = x[1:2:s1,:]
    return x_
end

function odd_duplicate(x)
    s1, s2 = size(x)
    x_ = zeros(s1, s2)
    x_[1:2:s1,:] = x[2:2:s1,:]
    x_[2:2:s1,:] = x[2:2:s1,:]
    return x_
end

function Even(m)
    x -> x .+ even_mask(x).*m(even_duplicate(x))
end

function InvEven(m)
    x -> x .- even_mask(x).*m(even_duplicate(x))
end

function Odd(m)
    x -> x .+ odd_mask(x).*m(odd_duplicate(x))
end

function InvOdd(m)
    x -> x .- odd_mask(x).*m(odd_duplicate(x))
end

m1 = Chain(Dense(4,6,relu), Dense(6,5,relu), Dense(5,4))
m2 = Chain(Dense(4,7,relu), Dense(7,4))

forward = Chain(Even(m1), Odd(m2))
inverse = Chain(InvOdd(m2), InvEven(m1))

function loss(x)
    z = forward(x)
    return 0.5*sum(z.*z)
end

opt = Flux.ADAM()

x = rand(4,100)

for i=1:100
    Flux.train!(loss, Flux.params(forward), x, opt)
    println(loss(x))
end

The forward model is a combination of m1 and m2. I need to optimize m1 and m2 so that both the forward and inverse models are optimized. But it seems that params(forward) is empty. How can I train my model?

Upvotes: 4

Views: 744

Answers (1)

contradict

Reputation: 161

I don't think plain functions can be used as trainable layers in Flux -- they carry no parameters for Flux to collect. You need to wrap the model in a struct and use the @functor macro to add the functionality to collect parameters: https://fluxml.ai/Flux.jl/stable/models/basics/#Layer-helpers-1

In your case, rewriting Even, InvEven, Odd and InvOdd like this should help:

struct Even
    model
end

(e::Even)(x) = x .+ even_mask(x).*e.model(even_duplicate(x))

Flux.@functor Even

After adding this definition,

Flux.params(Even(m1))

should return a non-empty list.
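The other three wrappers follow the same pattern. Only Even is shown above, so the following is a sketch of how the remaining definitions would look, mirroring your original closures:

struct Odd
    model
end

(o::Odd)(x) = x .+ odd_mask(x).*o.model(odd_duplicate(x))

Flux.@functor Odd

struct InvEven
    model
end

(e::InvEven)(x) = x .- even_mask(x).*e.model(even_duplicate(x))

Flux.@functor InvEven

struct InvOdd
    model
end

(o::InvOdd)(x) = x .- odd_mask(x).*o.model(odd_duplicate(x))

Flux.@functor InvOdd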

EDIT

An even simpler way to implement Even and friends is to use the built-in SkipConnection layer:

Even(m) = SkipConnection(Chain(even_duplicate, m),
                         (mx, x) -> x .+ even_mask(x) .* mx)
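For reference, SkipConnection(layers, connection) applied to x computes connection(layers(x), x), which is exactly the "transform a copy of the input, then combine with the original" pattern used here. A tiny toy example (not from the original post):

using Flux

sc = SkipConnection(x -> 2 .* x, (mx, x) -> mx .+ x)  # hypothetical layer: double x, then add x back
sc([1.0, 2.0])  # returns [3.0, 6.0], i.e. 2x + x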

I suspect this is a version difference, but with Julia 1.4.1 and Flux v0.10.4 I get the error BoundsError: attempt to access () at index [1] when running your training loop. I need to replace the data with

x = [(rand(4,100), 0)]

Otherwise the loss is applied to each entry of the array x, since train! splats loss over each element of x.
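Roughly speaking, train! does something like this for each element of the data iterable (a simplified sketch, not the actual Flux source):

for d in data
    gs = Flux.gradient(() -> loss(d...), ps)   # each data element d is splatted into loss
    Flux.Optimise.update!(opt, ps, gs)
end

so each element of data must be a tuple whose entries match the arguments of loss.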

The next error, Mutating arrays is not supported, is due to the implementation of *_mask and *_duplicate. These functions construct an array of zeros and then mutate it by filling in values from the input, and Zygote cannot differentiate through that mutation.

You can use Zygote.Buffer to implement this code in a way that can be differentiated.
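As a minimal illustration of the pattern (a hypothetical toy function, not from the original post): write into the Buffer with indexed assignment, then copy it back to an ordinary array before returning, so Zygote can differentiate through it.

using Zygote

function scale_evens(x)
    s = length(x)
    b = Zygote.Buffer(x)        # writable scratch array, same size and type as x
    b[1:2:s] = x[1:2:s]         # odd positions copied as-is
    b[2:2:s] = 2 .* x[2:2:s]    # even positions doubled
    return copy(b)              # convert back to a regular array for Zygote
end

Zygote.gradient(x -> sum(scale_evens(x)), [1.0, 2.0, 3.0, 4.0])  # ([1.0, 2.0, 1.0, 2.0],)

Applying the same idea to your masking and duplication functions gives: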

using Flux
using Zygote: Buffer

function even_mask(x)
    s1, s2 = size(x)
    weight_mask = Buffer(x)
    weight_mask[2:2:s1,:] = ones(Int(s1/2), s2)
    weight_mask[1:2:s1,:] = zeros(Int(s1/2), s2)
    return copy(weight_mask)
end

function odd_mask(x)
    s1, s2 = size(x)
    weight_mask = Buffer(x)
    weight_mask[2:2:s1,:] = zeros(Int(s1/2), s2)
    weight_mask[1:2:s1,:] = ones(Int(s1/2), s2)
    return copy(weight_mask)
end

function even_duplicate(x)
    s1, s2 = size(x)
    x_ = Buffer(x)
    x_[1:2:s1,:] = x[1:2:s1,:]
    x_[2:2:s1,:] = x[1:2:s1,:]
    return copy(x_)
end

function odd_duplicate(x)
    s1, s2 = size(x)
    x_ = Buffer(x)
    x_[1:2:s1,:] = x[2:2:s1,:]
    x_[2:2:s1,:] = x[2:2:s1,:]
    return copy(x_)
end

Even(m) = SkipConnection(Chain(even_duplicate, m),
                         (mx, x) -> x .+ even_mask(x) .* mx)

InvEven(m) = SkipConnection(Chain(even_duplicate, m),
                            (mx, x) -> x .- even_mask(x) .* mx)

Odd(m) = SkipConnection(Chain(odd_duplicate, m),
                        (mx, x) -> x .+ odd_mask(x) .* mx)

InvOdd(m) = SkipConnection(Chain(odd_duplicate, m),
                           (mx, x) -> x .- odd_mask(x) .* mx)

m1 = Chain(Dense(4,6,relu), Dense(6,5,relu), Dense(5,4))
m2 = Chain(Dense(4,7,relu), Dense(7,4))

forward = Chain(Even(m1), Odd(m2))
inverse = Chain(InvOdd(m2), InvEven(m1))

function loss(x, y)  # y is unused; it is only there because each data element is an (x, y) tuple
    z = forward(x)
    return 0.5*sum(z.*z)
end

opt = Flux.ADAM(1e-6)  # reduced learning rate; the default diverges to NaN (see below)

x = [(rand(4,100), 0)]

function train!()
    for i=1:100
        Flux.train!(loss, Flux.params(forward), x, opt)
        println(loss(x[1]...))
    end
end

At this point, you get to the real fun of deep networks. With the default learning rate, training diverges to NaN after one step. Reducing the initial learning rate to 1e-6 helps, and the loss looks like it is decreasing.

Upvotes: 4
