JAIME GOMEZ

How do I replace pixelClassificationLayer() with a custom loss function?

I have seen on the official MathWorks documentation page for the pixelClassificationLayer() function that I should replace it with a custom loss function, using the following code:

function loss = modelLoss(Y,T)
  mask = ~isnan(T);      % NaN targets mark unlabeled pixels
  T(isnan(T)) = 0;       % replace NaN so crossentropy receives finite targets
  loss = crossentropy(Y,T,Mask=mask,NormalizationFactor="mask-included");
end

netTrained = trainnet(images,net,@modelLoss,options);

However, that code makes no mention of the 'Classes' or 'ClassWeights' inputs, which I currently use to configure the layer:

pixelClassificationLayer('Classes',classNames,'ClassWeights',classWeights)

Here classNames is a vector containing the name of each class as a string, and classWeights is a vector containing the weight of each class, used to balance the classes when some of them are underrepresented in the training data.

How can I include these parameters in my custom loss function?

Answers (1)

seralouk

You need to explicitly account for these parameters within your custom loss function.

Below is an example; adjust it to your data:

function loss = modelLoss(Y, T, classNames, classWeights)
    % Y: softmax scores, H-by-W-by-numClasses-by-N
    % T: one-hot targets of the same size; unlabeled (<undefined>) pixels
    %    are NaN, as assumed by the MathWorks snippet in the question
    numClasses = numel(classNames);
    assert(numel(classWeights) == numClasses, ...
        'Need one weight per class, in the same order as classNames.');

    % normalize the weights so they sum to 1
    classWeights = classWeights / sum(classWeights);

    mask = ~isnan(T);
    T(isnan(T)) = 0;

    % class-wise weighted cross-entropy, summed over labeled pixels
    weightedLoss = 0;
    for c = 1:numClasses
        weightedLoss = weightedLoss + classWeights(c) * crossentropy( ...
            Y(:, :, c, :), T(:, :, c, :), ...
            Mask=mask(:, :, c, :), NormalizationFactor="none");
    end

    % Normalize by # of labeled pixels (mask is repeated in every channel)
    numValidPixels = sum(mask(:)) / numClasses;
    loss = weightedLoss / max(numValidPixels, 1);
end


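% Define class names and weights; classWeights(c) is applied to channel c
% of Y and T, so list both in the same class order
classNames = [...];
classWeights = [...]; % Example weights

% Wrap the extra parameters so trainnet receives a loss of the form f(Y, T)
customLoss = @(Y, T) modelLoss(Y, T, classNames, classWeights);

netTrained = trainnet(images, net, customLoss, options);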
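To sanity-check what a class-weighted, masked cross-entropy should return, here is a small worked example in plain MATLAB (no Deep Learning Toolbox needed). The image size, probabilities, and weights are made up for illustration. It assumes T is one-hot along the third (channel) dimension with NaN for unlabeled pixels, and it computes the sum over labeled pixels of -Σ_c w_c·T_c·log(Y_c), divided by the number of labeled pixels. Note that Deep Learning Toolbox's crossentropy also accepts a weights argument (per-class weights can be given with WeightsFormat="C"), which may let you drop the per-class loop; check its reference page for how that combines with Mask and NormalizationFactor.

```matlab
% Hypothetical 1-by-3 image, 2 classes, batch of 1: size(Y) = [1 3 2]
% Channel 1 = class 1, channel 2 = class 2 (softmax probabilities).
Y = zeros(1, 3, 2);
Y(1, :, 1) = [0.8 0.4 0.5];
Y(1, :, 2) = [0.2 0.6 0.5];

% One-hot targets: pixel 1 is class 1, pixel 2 is class 2,
% pixel 3 is unlabeled (NaN in every channel).
T = zeros(1, 3, 2);
T(1, :, 1) = [1 0 NaN];
T(1, :, 2) = [0 1 NaN];

classWeights = [1 3];   % hypothetical: class 2 is underrepresented

% Align the weights with the channel (3rd) dimension for implicit expansion.
w = reshape(classWeights, 1, 1, numel(classWeights));

mask = ~isnan(T);       % true where a target is defined
T(~mask) = 0;           % unlabeled pixels then contribute zero loss

% Weighted cross-entropy at each pixel: -sum_c w_c * T_c * log(Y_c)
perPixel = -sum(w .* T .* log(Y), 3);

% Average over labeled pixels only (a pixel is labeled if any channel is defined)
numValid = nnz(any(mask, 3));
loss = sum(perPixel(:)) / max(numValid, 1);

fprintf('loss = %.4f\n', loss);   % prints loss = 0.8778
```

Here the underrepresented class 2 pixel contributes 3·(-log 0.6) ≈ 1.5325, versus -log 0.8 ≈ 0.2231 for the class 1 pixel, and the unlabeled pixel contributes nothing, so the weights shift the loss toward the rarer class.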
