logankilpatrick

Reputation: 14521

Built-in Flux.jl method to detect NaNs

A problem I run into a lot is that my model contains matrices with NaN values. Is there a common Flux method I can pass my matrices to in order to detect these NaNs? I know Julia has a built-in isnan() function that could be used in some cases, but I am not sure whether there is a Flux-specific version.

Upvotes: 0

Views: 261

Answers (1)

darsnack

Reputation: 925

No, there is no Flux-specific function. In most cases, any(isnan, A) is probably what you want. One Flux-related "enhancement" would be to use training-loop callbacks to stop training as soon as NaNs are detected.
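For reference, here is a minimal Base-Julia sketch of that check (no Flux needed). The matrix `A` is a made-up example; `findall` and `count` are standard Base functions that help when you need to know where, or how many, NaNs there are rather than just whether one exists.

```julia
# Made-up example matrix with one NaN entry.
A = [1.0 NaN; 3.0 4.0]

# `any(isnan, A)` applies `isnan` elementwise and stops at the first `true`,
# without allocating the intermediate Bool array that `any(isnan.(A))` builds.
has_nan = any(isnan, A)            # true

# Locate and count the offending entries.
nan_positions = findall(isnan, A)  # Vector of CartesianIndex values
nan_count = count(isnan, A)        # 1
```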

# assumes (x, y) is your training data
#  and loss(x, y, model) computes the loss of model on (x, y)
cb = () -> isnan(loss(x, y, model)) && Flux.stop()

# basic train loop
# assuming opt is your optimizer
Flux.train!((x, y) -> loss(x, y, model), params(model), [(x, y)], opt; cb = cb)

The above example shows the basic idea, and you can extend it to check other arrays for NaNs. For example, you could do

cb = () -> any(params(model)) do p
    any(isnan, p)
end && Flux.stop()

to check whether any parameter array contains a NaN.
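The callbacks above rely on the `cb` keyword of `Flux.train!` and on `Flux.stop()`, which newer Flux releases may no longer support in the same form. The same early-stop idea can be written as an ordinary hand-rolled loop. The sketch below is plain Julia with no Flux dependency; the helper names `has_nan_params` and `train_with_nan_check`, the one-weight toy model, the learning rate, and the data are all hypothetical, chosen only to show a loss check and a parameter check that end training early.

```julia
# Hypothetical helper: true if any array in an iterable of parameter arrays
# contains a NaN (the same test as the second callback above).
has_nan_params(ps) = any(p -> any(isnan, p), ps)

# Toy model y ≈ w * x with w stored in a 1-element vector so it can be checked
# like a parameter array. Trained by plain gradient descent on squared error.
function train_with_nan_check(data; lr = 0.01)
    w = [0.0]
    for (step, (x, y)) in enumerate(data)
        loss = (w[1] * x - y)^2
        # Stop before the NaN can propagate into the parameters.
        if isnan(loss) || has_nan_params((w,))
            @warn "NaN detected; stopping early" step
            return w, step
        end
        w[1] -= lr * 2 * (w[1] * x - y) * x  # gradient of the squared error
    end
    return w, nothing
end

# Made-up data with a NaN input in the third sample.
data = [(1.0, 2.0), (2.0, 4.0), (NaN, 6.0), (4.0, 8.0)]
w, stopped_at = train_with_nan_check(data)  # stops at step 3
```

Because the check runs before the update, `w` still holds the last finite value when the loop returns, which makes it easier to inspect what went wrong.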

Upvotes: 1
