logankilpatrick

How to define a custom training loop in Flux.jl

I am trying to set up the training loop for an ML workflow using Flux.jl. I know I can use the built-in Flux.train!() function to do the training, but I need more customization than that API offers out of the box. How can I define my own custom training loop in Flux?

Answers (1)

logankilpatrick

Per the Flux.jl docs on Training Loops, you can do something like:

using Flux

# Implicit-parameter (Params) style: `ps` is typically Flux.params(model) and `opt` a Flux.Optimise optimiser.
function my_custom_train!(loss, ps, data, opt)
  # training_loss is declared local so it will be available for logging outside the gradient calculation.
  local training_loss
  ps = Flux.Params(ps)
  for d in data
    # Each `d` is expected to be a tuple such as (x, y), which is splatted into loss(x, y).
    gs = Flux.gradient(ps) do
      training_loss = loss(d...)
      # Code inserted here will be differentiated. Unless you need that gradient information,
      # it is better to do the work outside this block.
      return training_loss
    end
    # Insert whatever code you want here that needs training_loss, e.g. logging.
    # logging_callback(training_loss)
    # Insert whatever code you want here that needs the gradient,
    # e.g. logging it as a histogram with TensorBoardLogger.jl to see if it is becoming huge.
    Flux.Optimise.update!(opt, ps, gs)
    # Here you might like to check validation set accuracy, and break out to do early stopping.
  end
end
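
Newer Flux releases favor explicit gradients over the implicit `Params` style used above. The following is a minimal sketch of the same loop in that style, assuming a Flux release that provides `Flux.setup` and `Flux.withgradient` (0.14 or later). The names `my_explicit_train!`, `mse_loss`, `val_loss`, and `patience` are hypothetical, and the early-stopping rule (stop after `patience` consecutive steps without a better validation loss) is just one simple way to fill in the "break out to do early stopping" comment.

```julia
using Flux

# Hypothetical explicit-gradient version of my_custom_train!, assuming Flux 0.14 or later.
# `loss(model, x, y)` and `val_loss(model)` are user-supplied functions.
function my_explicit_train!(loss, model, data, opt_state; val_loss = nothing, patience = 5)
    history = Float32[]
    best, stale = Inf, 0
    for (x, y) in data
        # Only the closure passed to withgradient is differentiated.
        training_loss, grads = Flux.withgradient(m -> loss(m, x, y), model)
        push!(history, training_loss)            # logging stays outside the differentiated code
        Flux.update!(opt_state, model, grads[1])
        # Early stopping: stop once the validation loss fails to improve `patience` times in a row.
        if val_loss !== nothing
            v = val_loss(model)
            if v < best
                best, stale = v, 0
            else
                stale += 1
                stale >= patience && break
            end
        end
    end
    return history
end

# Usage on toy data: learn y = x1 + x2 with a single Dense layer.
model = Dense(2 => 1)
xtrain = rand(Float32, 2, 64); ytrain = sum(xtrain; dims = 1)
xval = rand(Float32, 2, 16); yval = sum(xval; dims = 1)
mse_loss(m, x, y) = Flux.mse(m(x), y)
data = [(xtrain, ytrain) for _ in 1:200]
opt_state = Flux.setup(Adam(0.01), model)
history = my_explicit_train!(mse_loss, model, data, opt_state;
                             val_loss = m -> mse_loss(m, xval, yval))
```

Here `Flux.setup` creates the optimiser state once, outside the loop, and `grads[1]` is the gradient with respect to the model, the first (and only) argument of the differentiated closure.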

You can also simplify `my_custom_train!` by hardcoding the loss function inside it instead of passing `loss` in as an argument.
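
As a rough illustration, here is a sketch of a loop with the loss hardcoded to mean squared error, assuming a Flux release that provides `Flux.setup` and `Flux.withgradient` (0.14 or later). `mse_train!` and the toy data are hypothetical; the choice of `Flux.mse` as the hardcoded loss is only an example.

```julia
using Flux

# Hypothetical simplified loop: the loss (mean squared error) is hardcoded, so callers
# pass only the model, the data, and the optimiser state. Assumes Flux 0.14 or later.
function mse_train!(model, data, opt_state)
    losses = Float32[]
    for (x, y) in data
        l, grads = Flux.withgradient(model) do m
            Flux.mse(m(x), y)    # hardcoded loss, differentiated with respect to the model
        end
        push!(losses, l)         # record the loss outside the differentiated block
        Flux.update!(opt_state, model, grads[1])
    end
    return losses
end

# Toy usage: 100 steps fitting y = x1 + x2 with a single Dense layer.
model = Dense(2 => 1)
x = rand(Float32, 2, 64)
y = sum(x; dims = 1)
opt_state = Flux.setup(Adam(0.01), model)
losses = mse_train!(model, [(x, y) for _ in 1:100], opt_state)
```

The trade-off is flexibility: the function is shorter to call, but switching to a different loss means editing the loop itself.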
