I am trying to set up the training loop for an ML workflow using Flux.jl. I know I can use the built-in Flux.train!() function to do the training, but I need more customization than the API gives me out of the box. How can I define my own custom training loop in Flux?
Per the Flux.jl docs on Training Loops, you can do something like:
```julia
using Flux

function my_custom_train!(loss, ps, data, opt)
    # training_loss is declared local so it is available for logging outside the gradient calculation.
    local training_loss
    ps = Flux.Params(ps)
    for d in data
        gs = gradient(ps) do
            training_loss = loss(d...)
            # Code inserted here will be differentiated; unless you need its gradient,
            # it is better to do that work outside this block.
            return training_loss
        end
        # Insert whatever code you want here that needs training_loss, e.g. logging.
        # logging_callback(training_loss)
        # Insert whatever code you want here that needs the gradient, e.g. logging it
        # with TensorBoardLogger.jl as a histogram to see if it is becoming huge.
        Flux.Optimise.update!(opt, ps, gs)
        # Here you might like to check validation set accuracy, and break out to do early stopping.
    end
end
```
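On newer Flux releases (0.14 and later) the implicit Params style used in the loop above is being phased out in favour of explicit gradients taken with respect to the model itself. The following is a minimal sketch of the same loop in explicit style, assuming Flux 0.14+ (Flux.setup, Flux.withgradient and Flux.update!), with the logging and early-stopping placeholders filled in. The function name my_explicit_train! and the stop_below keyword are hypothetical illustrations, not part of Flux.

```julia
using Flux

# Explicit-style custom loop (assumes Flux 0.14+): gradients are taken with
# respect to the model rather than an implicit Params collection.
# `stop_below` is a hypothetical early-stopping threshold, not a Flux option.
function my_explicit_train!(loss, model, data, opt_state; stop_below = 0.0)
    losses = Float32[]                  # per-batch training loss, kept for logging
    for (x, y) in data
        # withgradient returns the loss value and a tuple of gradients,
        # one entry per argument of the closure (here only the model).
        training_loss, grads = Flux.withgradient(m -> loss(m(x), y), model)
        push!(losses, training_loss)    # logging hook: swap in your own callback
        Flux.update!(opt_state, model, grads[1])
        # Early stopping: leave the loop once the batch loss is small enough.
        training_loss < stop_below && break
    end
    return losses
end

# Usage on synthetic data: fit y = 2x + 1 with a single Dense layer.
model = Dense(1 => 1)
xs = rand(Float32, 1, 64)
data = [(xs, 2 .* xs .+ 1)]
opt_state = Flux.setup(Descent(0.1), model)
losses = Float32[]
for epoch in 1:200
    append!(losses, my_explicit_train!(Flux.Losses.mse, model, data, opt_state))
end
```

Returning the loss history instead of printing it keeps the loop free of side effects, so the caller decides whether to log, plot or stop.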
You could simplify the example above further, for example by hard-coding the loss function instead of passing it in as an argument.
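As a sketch of that simplification, the version below bakes mean squared error into the loop so that only the model, data and optimiser are passed in. It keeps the implicit Params style of the original and assumes a Flux version (roughly 0.13 to 0.14) that still ships Flux.Optimise; the name my_mse_train! and the choice of MSE are hypothetical, picked only for illustration.

```julia
using Flux

# Implicit-Params loop with the loss hard-coded as mean squared error.
# Assumes Flux 0.13-0.14, where Flux.Optimise.update!(opt, ps, gs) exists.
function my_mse_train!(model, data, opt)
    ps = Flux.params(model)             # implicit parameter collection
    local training_loss                 # visible after the gradient block
    for (x, y) in data
        gs = gradient(ps) do
            training_loss = Flux.Losses.mse(model(x), y)   # hard-coded loss
            return training_loss
        end
        Flux.Optimise.update!(opt, ps, gs)
    end
    return training_loss                # loss of the last batch (data must be non-empty)
end

# Usage on synthetic data: fit y = 2x + 1 with a single Dense layer.
model = Dense(1 => 1)
xs = rand(Float32, 1, 64)
data = [(xs, 2 .* xs .+ 1)]
opt = Descent(0.1)
losses = [my_mse_train!(model, data, opt) for epoch in 1:200]
```

The trade-off is flexibility: the signature gets shorter, but switching to a different loss now means editing the function body.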