Reputation: 14521
A problem I run into a lot is that matrices in my model end up containing NaN values. Is there a common Flux method I can pass my matrices to in order to detect these NaNs? I know Julia has a built-in isnan() function that could be used in some cases, but I am not sure whether there is a Flux-specific version.
Upvotes: 0
Views: 261
Reputation: 925
No, there is no Flux-specific function. Using any(isnan, A) is probably what you want in most cases. One Flux-related "enhancement" would be to use a training-loop callback to stop training if NaNs are detected:
using Flux
# assumes (x, y) is your training data
# and loss(x, y, model) computes the loss of model on (x, y)
cb = () -> isnan(loss(x, y, model)) && Flux.stop()
# basic training loop
# assumes opt is your optimizer
Flux.train!((x, y) -> loss(x, y, model), Flux.params(model), [(x, y)], opt; cb = cb)
The above example is the basic idea, and you can extend it to check other arrays for NaN. For example, you could do
cb = () -> any(Flux.params(model)) do p
    any(isnan, p)
end && Flux.stop()
to check whether any parameter array contains a NaN.
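The any(isnan, A) check itself needs nothing from Flux, so it can be tried on plain arrays first. The following is a minimal sketch in base Julia; the matrices A and B and the weights list are made-up stand-ins for model arrays, not anything from the question.

```julia
# Base Julia only; Flux is not needed for the check itself.
A = [1.0 2.0; NaN 4.0]   # hypothetical matrix with one NaN entry
B = [1.0 2.0; 3.0 4.0]   # hypothetical matrix with no NaN entries

has_nan_A = any(isnan, A)   # true
has_nan_B = any(isnan, B)   # false

# findall reports where the NaNs are (CartesianIndex values for a matrix).
nan_positions = findall(isnan, A)

# The same check applied to a collection of arrays,
# here a hypothetical list standing in for a model's weight arrays.
weights = [B, A]
any_bad = any(w -> any(isnan, w), weights)
```

Using findall(isnan, A) instead of any(isnan, A) can help when debugging, since it shows which entries went bad rather than only that something did.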
Upvotes: 1