Valeria

Reputation: 1220

Julia Flux: determine type of layer

I am new to Julia and I am having trouble determining the type of a layer in a Flux model. For example, imagine that my model is just one neuron:

using Flux
m = Chain(Dense(1, 1, sigmoid))

I want to iterate over my Chain and, depending on the type of each layer, take different actions (specifically, I want to add regularization for the Dense layers).

I come to Julia from Python, and my first guess was to compare the type of the layer to the type of Dense. Contrary to my intuition, this gives me false:

for layer in m
    println(typeof(layer) == typeof(Dense))
end
  1. Why does this not work in Julia?
  2. What is the proper way to do this in Julia? Of course, I could check whether the specific fields of the struct (in/out/sigmoid in the case of Dense) exist for the given layer, but there would be no guarantee that it is not some other layer with analogous fields.

Upvotes: 2

Views: 360

Answers (2)

Bogumił Kamiński

Reputation: 69909

Use the layers property of m instead, and to check whether a value is of a given type, use isa. In summary, this should work:

for layer in m.layers
    if layer isa Dense
        # do something with dense layer
    else
        # do something else
    end
end

EDIT: Indeed, m supports iteration and indexing, which I did not know, so, as @darsnack suggested, this is enough:

for layer in m
    if layer isa Dense
        # do something with dense layer
    else
        # do something else
    end
end
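To see that isa matches a parametric type without spelling out its type parameters, here is a minimal sketch that runs without Flux. MockDense and MockActivation are hypothetical stand-ins for Flux layers (not Flux's own types), and the tuple layers plays the role of m.layers:

```julia
# Hypothetical stand-ins for Flux layers, so the sketch runs without Flux installed.
struct MockDense{M<:AbstractMatrix}
    weight::M
end

struct MockActivation
    f::Function
end

# A tuple of layers, playing the role of `m.layers` in a Chain.
layers = (MockDense(ones(2, 3)), MockActivation(tanh), MockDense(zeros(1, 2)))

for layer in layers
    # `isa MockDense` matches MockDense{Matrix{Float64}} without naming the parameter.
    if layer isa MockDense
        println("dense-like layer, weight size ", size(layer.weight))
    else
        println("other layer: ", typeof(layer))
    end
end

# The same check used as a predicate: count the dense-like layers.
n_dense = count(layer -> layer isa MockDense, layers)
println(n_dense)  # prints 2
```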

Now to clarify type checking:

  • if you have a value and you want to check whether its type is a subtype of a given type, use isa as I have above
  • if you have two types and you want to compare them for subtyping, use <:, so you could have written typeof(layer) <: Dense; for types, == checks are not recommended, see this warning in the Julia manual
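A quick way to see the difference between the three checks, using only Base types (the concrete type Float64 and its abstract supertype Real):

```julia
x = 1.5                         # a Float64 value

println(x isa Float64)          # true: value checked against its own type
println(x isa Real)             # true: isa also accepts abstract supertypes
println(typeof(x) <: Real)      # true: the same question asked type-to-type
println(typeof(x) == Real)      # false: == asks for exactly the same type
```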

You can check out this section of the Julia manual to read more about it.

Upvotes: 5

darsnack

Reputation: 925

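Iterating with for layer in m should be fine. The reason you get false is that typeof(Dense) == UnionAll: Dense is a parametric type, so Dense on its own is a UnionAll, while typeof(layer) is a concrete Dense{...} type with its parameters filled in. For the same reason, typeof(layer) == Dense is also false, so check for subtyping instead:

for layer in m
    println(typeof(layer) <: Dense)
end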
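The same effect can be reproduced with Base's Vector, which, like Dense, is a UnionAll until its element-type parameter is filled in. A small sketch, no Flux needed:

```julia
v = Float64[1, 2, 3]                   # typeof(v) is Vector{Float64}

println(typeof(Vector))                # UnionAll: Vector alone has a free type parameter
println(typeof(v) == typeof(Vector))   # false: compares Vector{Float64} with UnionAll
println(typeof(v) == Vector)           # false: Vector{Float64} is not the same type as Vector
println(typeof(v) <: Vector)           # true: subtyping ignores the filled-in parameter
println(v isa Vector)                  # true
```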

A more Julian approach is to dispatch on the layer type like so:

function processlayer(layer::Dense)
    # do special thing for dense
end

function processlayer(layer)
    # do other thing for anything else
end

for layer in m
    processlayer(layer)
end
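Applied to the original goal of regularizing only the Dense layers, dispatch could look roughly like this. It is a sketch under assumptions: MockDense, MockActivation, and penalty are hypothetical names used so the code runs without Flux. With real Flux layers you would read the weight matrix from the Dense layer's own field, whose name has differed between Flux versions, so check the version you use:

```julia
# Hypothetical stand-ins for Flux layers (Flux itself is not needed to run this).
struct MockDense{M<:AbstractMatrix}
    weight::M
end

struct MockActivation
    f::Function
end

# Method for any MockDense, whatever its type parameter: L2 penalty on the weights.
penalty(layer::MockDense) = sum(abs2, layer.weight)

# Fallback method for every other layer type: contributes nothing.
penalty(layer) = 0.0

layers = (MockDense([1.0 2.0; 3.0 4.0]), MockActivation(tanh), MockDense([0.5 0.5]))

# 1 + 4 + 9 + 16 from the first layer, 0 from the activation, 0.25 + 0.25 from the last.
reg = sum(penalty, layers)
println(reg)  # prints 30.5
```

Because the fallback method returns 0.0, new layer types need no extra code unless they should also be penalized.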

Upvotes: 2
