logankilpatrick

Reputation: 14521

How to define a custom loss function in Flux.jl?

Looking at the Flux.jl docs, I see there are a ton of built-in loss functions: https://fluxml.ai/Flux.jl/stable/models/losses/. My question is: how can I define and use my own loss function in Flux if I want something more esoteric?

Upvotes: 0

Views: 530

Answers (1)

Matěj Račinský

Reputation: 1804

You can use any differentiable function that returns a single scalar (e.g. a float) as your loss. As stated in the comment above, the predefined loss functions are only there for your convenience. You can pass anything, e.g.

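using Flux
# Cross-entropy on raw model outputs (logits): logsoftmax normalises over the classes (dims = 1),
# then the result is summed over the classes and over the batch, giving one scalar.
yourcustomloss(ŷ, y) = sum(.- sum(y .* logsoftmax(ŷ), dims = 1))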

and then compute its gradient (e.g. with gradient) or pass it to the train! function.
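A minimal sketch of both uses, assuming Flux 0.14 or later with its explicit-gradient API (Flux.setup, Flux.update! and the four-argument train! whose loss takes the model first). The toy model, random data and learning rate are hypothetical and chosen only for illustration:

```julia
using Flux

# Custom loss from the answer: cross-entropy computed on raw logits.
yourcustomloss(ŷ, y) = sum(.- sum(y .* logsoftmax(ŷ), dims = 1))

# Hypothetical toy setup: 4 input features, 3 classes, a batch of 8 samples (columns).
model = Dense(4 => 3)
x = rand(Float32, 4, 8)
y = Flux.onehotbatch(rand(1:3, 8), 1:3)

# Gradient of the custom loss with respect to the model's parameters (explicit style).
grads = Flux.gradient(m -> yourcustomloss(m(x), y), model)

# Apply one optimisation step by hand.
opt_state = Flux.setup(Adam(0.01), model)
Flux.update!(opt_state, model, grads[1])

# Or let train! iterate over the data; here the loss receives the model first.
data = [(x, y)]
Flux.train!((m, xb, yb) -> yourcustomloss(m(xb), yb), model, data, opt_state)
```

As a sanity check, this particular loss should agree with the built-in Flux.logitcrossentropy(ŷ, y; agg = sum), which is a convenient way to verify a hand-written loss before swapping in something more esoteric.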

Upvotes: 1
