corvo

Reputation: 724

Custom Loss in keras using wrapper throwing "referenced before assignment" error

I am trying to use a wrapper to include the network's inputs in the loss computation, passing the input as model.input to the loss wrapper function. The custom loss function with the wrapper is:

def custom_loss_wrapper(input_tensor):
    def custom_loss(y_true, y_pred):
        y_pred = K.print_tensor(y_pred, message="y_pred - ")
        input_tensor_2 = input_tensor
        input_tensor = K.print_tensor(input_tensor, message="input_tensor - ")

        y_true_1 = [0.1, 0.2]
        y_true_2 = [0.3, 0.2]

        bool_2 = input_tensor - tf.constant([1], dtype="float32")
        bool_2 = K.print_tensor(bool_2, message="bool_2 - ")
        bool_1 = tf.constant([2], dtype="float32") - input_tensor
        bool_1 = K.print_tensor(bool_1, message="bool_1 - ")

        y_true_1_tf = tf.constant([y_true_1], dtype=tf.float32)
        y_true_1_bool = y_true_1_tf * bool_1

        y_true_2_tf = tf.constant([y_true_2], dtype=tf.float32)
        y_true_2_bool = y_true_2_tf * bool_2

        y_true_custom = y_true_1_bool + y_true_2_bool
        #y_true_custom = K.print_tensor(y_true_custom, message="y_true_custom - ")

        loss = K.square(y_pred - y_true)
        #loss = K.print_tensor(loss, message="loss - ")
        return loss
    return custom_loss

and the code to compile the model is:

model.compile(loss=custom_loss_wrapper(model.input), optimizer='adam')

This throws the following error:

UnboundLocalError: in user code:

    <ipython-input-3-a2ee38c1577c>:16 custom_loss  *
        input_tensor_2 = input_tensor

    UnboundLocalError: local variable 'input_tensor' referenced before assignment

Specifically, the error is thrown on the second line of the inner custom_loss function.

Any help on why this is happening and how to correct it would be appreciated.

Upvotes: 1

Views: 269

Answers (2)

Marcus

Reputation: 428

Edited answer: The cause is how you have named your variables. If you run this code, you will hit the same error:

def test(x):
  def next_level():
    x = x
    print("What to print?", x)
  return next_level

a = test("Print this.")
a()

But if you remove x = x (or change it to y = x), the code runs just fine. Because the inner function assigns to x, Python treats x as a local variable throughout next_level, so the read on the right-hand side of x = x happens before the local x exists. If you never assign to x in the inner function, x refers to the variable defined in the outer function.
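To make the two standard fixes concrete, here is a minimal sketch (plain Python, no Keras needed): either read the outer variable without ever assigning to it, or declare it nonlocal if you really need to rebind it. The names outer and outer_nonlocal are just illustrative:

```python
def outer(x):
    def inner():
        # Reading x is fine: inner never assigns to x,
        # so Python resolves it in the enclosing scope.
        y = x
        return y
    return inner

def outer_nonlocal(x):
    def inner():
        nonlocal x        # opt back into the enclosing scope's x
        x = x + "!"       # rebinding is now allowed
        return x
    return inner

print(outer("Print this.")())       # reads the closed-over x
print(outer_nonlocal("Hello")())    # rebinds the closed-over x
```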

By making sure that you don't assign to input_tensor in the inner function, the input_tensor from the outer scope will be used. This should do the trick:

def custom_loss_wrapper(input_tensor):
  def custom_loss(y_true, y_pred):
    y_pred = K.print_tensor(y_pred, message="y_pred - ")
    input_tensor_2 = K.print_tensor(input_tensor, message="input_tensor - ")

    y_true_1 = [0.1, 0.2]
    y_true_2 = [0.3, 0.2]

    bool_2 = input_tensor_2 - tf.constant([1], dtype="float32")
    bool_2 = K.print_tensor(bool_2, message="bool_2 - ")
    bool_1 = tf.constant([2], dtype="float32") - input_tensor_2
    bool_1 = K.print_tensor(bool_1, message="bool_1 - ")

    y_true_1_tf = tf.constant([y_true_1], dtype=tf.float32)
    y_true_1_bool = y_true_1_tf * bool_1

    y_true_2_tf = tf.constant([y_true_2], dtype=tf.float32)
    y_true_2_bool = y_true_2_tf * bool_2

    y_true_custom = y_true_1_bool + y_true_2_bool
    #y_true_custom = K.print_tensor(y_true_custom, message="y_true_custom - ")

    loss = K.square(y_pred - y_true)
    #loss = K.print_tensor(loss, message="loss - ")
    return loss
  return custom_loss

Upvotes: 3

Minh-Long Luu

Reputation: 2731

This part

def custom_loss_wrapper(input_tensor):
    def custom_loss(y_true, y_pred):
        y_pred = K.print_tensor(y_pred, message="y_pred - ")
        input_tensor_2 = input_tensor

This is a nested function. The inner function does see names from the outer function through its closure, but only as long as it never assigns to them. Because custom_loss later assigns input_tensor = K.print_tensor(...), Python treats input_tensor as a local variable for the entire function body, so the earlier read on input_tensor_2 = input_tensor fails before the local is ever bound.
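A small sketch of that rule, with illustrative names: the closure read works fine on its own, and the exact same read fails as soon as an assignment to the name appears anywhere later in the function:

```python
def wrapper(t):
    def reads_only():
        return t + 1      # fine: t resolves to the enclosing scope

    def reads_then_writes():
        a = t             # UnboundLocalError at runtime...
        t = a             # ...because this assignment makes t local everywhere
        return t

    return reads_only, reads_then_writes

ok, bad = wrapper(41)
print(ok())               # the closure read succeeds
try:
    bad()
except UnboundLocalError as e:
    print("caught:", e)   # same error as in the question
```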

Upvotes: 0
