user2299067

PyTorch Custom Loss Function with If Statement

I am trying to create a custom loss function in PyTorch that evaluates each element of a tensor with an if statement and acts accordingly.

def my_loss(outputs,targets,fin_val):
    if (outputs.detach()-fin_val.detach())*(targets.detach()-fin_val.detach())<0:
        loss=3*(outputs-targets)**2
    else:
        loss=0.3*(outputs-targets)**2
    return loss

I have also tried:

def my_loss(outputs,targets,fin_val):
    if torch.gt((outputs.detach()-fin_val.detach())*(targets.detach()-fin_val.detach()),0):
        loss=0.3*(outputs-targets)**2
    else:
        loss=3*(outputs-targets)**2
    return loss

In both cases, I get the following error:

RuntimeError: Boolean value of Tensor with more than one value is ambiguous

TIA

Ivan

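You are getting this error because the condition you pass to the if statement is not a single boolean but a tensor of booleans, one per element. Python's if needs exactly one truth value, so PyTorch refuses to guess whether you mean "any element is True" or "every element is True". Check what (outputs.detach()-fin_val.detach())*(targets.detach()-fin_val.detach())<0 actually is: it is a tensor, not a bool.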
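To make this concrete, here is a small sketch that reproduces the error. The values of outputs, targets and fin_val are made up purely for illustration; the point is that the comparison yields a bool tensor, and putting it in an if triggers the same RuntimeError.

```python
import torch

# Hypothetical example values, chosen only to trigger the error.
outputs = torch.tensor([1.0, 3.0, 2.5])
targets = torch.tensor([3.0, 4.0, 1.0])
fin_val = torch.tensor(2.0)

# Element-wise comparison: the result is a bool tensor with one entry per element.
cond = (outputs - fin_val) * (targets - fin_val) < 0
print(cond.dtype, cond.shape)

error_message = None
try:
    if cond:  # `if` calls bool(cond), which needs exactly one element
        pass
except RuntimeError as err:
    error_message = str(err)
print("RuntimeError:", error_message)

# Reducing to one value makes `if` legal, but it answers a different question
# ("is any/every element True?") than branching per element.
print(bool(cond.any()), bool(cond.all()))
```

Note that .any() or .all() would silence the error but apply one weight to the whole batch, which is not what the question asks for.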

Instead, handle the branching in vectorized form. torch.where is built for exactly this use:

torch.where((outputs - fin_val) * (targets - fin_val) < 0,  # condition
            3 * (outputs - targets)**2,                      # used where condition is True
            0.3 * (outputs - targets)**2)                    # used where condition is False

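This returns a tensor that takes, element by element, the second argument where the condition is True and the third argument where it is False. The arguments are passed positionally because the documented keyword names have changed across PyTorch versions (x and y in older docs, input and other in current ones). The .detach() calls are not needed in the condition, since a comparison produces a boolean tensor that does not track gradients. Finally, reduce the result (for example with .mean() or .sum()) to get the scalar loss you call .backward() on.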
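Putting it together, a minimal sketch of the full loss function might look like the following. The weight parameter names wrong_side_weight and right_side_weight are hypothetical additions for readability (their defaults match the 3 and 0.3 from the question), and the choice of .mean() as the reduction is an assumption; use .sum() if that suits your training loop better.

```python
import torch


def my_loss(outputs, targets, fin_val, wrong_side_weight=3.0, right_side_weight=0.3):
    """Weighted squared error, vectorized with torch.where.

    Elements where outputs and targets lie on opposite sides of fin_val
    get wrong_side_weight; all other elements get right_side_weight.
    """
    sq_err = (outputs - targets) ** 2
    # Boolean mask, one entry per element; comparisons carry no gradient.
    opposite_sides = (outputs - fin_val) * (targets - fin_val) < 0
    # Positional arguments: (condition, value_if_true, value_if_false).
    per_element = torch.where(opposite_sides,
                              wrong_side_weight * sq_err,
                              right_side_weight * sq_err)
    # Reduce to a scalar so that .backward() can be called on it.
    return per_element.mean()


# Hypothetical example values, for illustration only.
outputs = torch.tensor([1.0, 3.0, 2.5], requires_grad=True)
targets = torch.tensor([3.0, 4.0, 1.0])
fin_val = torch.tensor(2.0)

loss = my_loss(outputs, targets, fin_val)
loss.backward()
print(loss.item())
print(outputs.grad)
```

Gradients flow through whichever branch torch.where selects for each element, so elements on the "wrong side" of fin_val are pushed ten times harder than the others. An equivalent formulation multiplies sq_err by a weight tensor built from the boolean mask.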
