GoTN

Reputation: 135

Access layer attribute in custom loss function in Keras

I want to write a custom loss function in Keras which depends on an attribute of a (custom) layer in the network.

The idea is that the layer draws a new value a on every call and the loss function uses that same value.

Some example code to make it clearer:

import numpy as np
from keras import losses, layers, models

class MyLayer(layers.Layer):
    def call(self, x):
        a = np.random.rand()
        self.a = a # <-- does this work as expected?
        return x+a

def my_loss(layer):
    def modified_loss(y_true, y_pred):
        a = layer.a
        y_true = y_true + a
        return losses.mean_squared_error(y_true, y_pred)
    return modified_loss  # the outer function must return the inner loss

input_layer = layers.Input(shape=(4,))  # input shape assumed for illustration
my_layer = MyLayer(name="my_layer")(input_layer)  # create the layer, then call it on the input
output_layer = layers.Dense(4)(my_layer)
model = models.Model(inputs=input_layer, outputs=output_layer)
model.compile('adam', my_loss(model.get_layer("my_layer")))

I expect a to change for every batch, and the same a to be used in both the layer and the loss function. Right now it does not work the way I intended: the a in the loss function seems never to be updated (and maybe not even in the layer).

How do I change the attribute/value of a in the layer at every call and access it in the loss function?

Upvotes: 2

Views: 779

Answers (1)

Stewart_R

Reputation: 14505

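Not quite sure I am following the purpose of this (and I am bothered by the call to np.random inside the call() of your custom layer; could you not use the tf.random functions instead?), but you can certainly access the a attribute inside your loss function, as long as the loss function holds a reference to the layer instance itself rather than to the tensor the layer outputs.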
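A minimal sketch of why the np.random call matters, assuming TensorFlow 2.x: during training Keras runs call() inside a traced tf.function, so plain Python/NumPy code such as np.random.rand() executes only while the function is being traced, and its result is frozen into the graph as a constant. tf.random.uniform, by contrast, becomes a graph op that is re-evaluated on every run. The function names np_offset and tf_offset are hypothetical and used only for this illustration.

```python
import numpy as np
import tensorflow as tf


@tf.function
def np_offset(x):
    # np.random.rand() runs in Python only while the function is being traced,
    # so the number it draws is baked into the graph as a constant.
    return x + np.random.rand()


@tf.function
def tf_offset(x):
    # tf.random.uniform is a graph op, so it is re-evaluated on every call.
    return x + tf.random.uniform(())


x = tf.constant(0.0)
# Same input signature each time, so both functions reuse their first trace.
np_values = {float(np_offset(x)) for _ in range(5)}
tf_values = {float(tf_offset(x)) for _ in range(5)}
print(len(np_values), len(tf_values))
```

Because the input signature never changes, np_values ends up with a single element, while tf_values gets a fresh element on (practically) every call. This matches the symptom in the question of a never appearing to change.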

Perhaps something like:

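import numpy as np
from keras import losses, layers, models

class MyLayer(layers.Layer):
    def call(self, x):
        a = np.random.rand() # FIXME --> use tf.random
        self.a = a
        return x+a

input_layer = layers.Input(shape=(4,))  # input shape assumed for illustration
my_layer = MyLayer(name="my_layer")  # keep a reference to the layer object itself
output_layer = layers.Dense(4)(my_layer(input_layer))
model = models.Model(inputs=input_layer, outputs=output_layer)

def my_loss(y_true, y_pred):
    # my_layer is the layer instance, so its attribute a is visible here
    y_true = y_true + my_layer.a
    return losses.mean_squared_error(y_true, y_pred)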


model.compile('adam', loss=my_loss)
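The code above still draws a with np.random.rand(), so it keeps the trace-time problem flagged in its FIXME. Below is a minimal sketch of one way to get a fresh a per batch that both the layer and the loss see, assuming TensorFlow 2.x with tf.keras. The layer stores a in a non-trainable tf.Variable, assigns it a tf.random.uniform value inside call(), and the loss reads that variable. Inside the compiled training step TensorFlow executes operations on the same variable in program order, so the loss should read the value assigned during the forward pass of the same batch. The input shape (4,) and the use of a raw tf.Variable are assumptions for illustration.

```python
import numpy as np
import tensorflow as tf


class MyLayer(tf.keras.layers.Layer):
    """Adds a random scalar to its input and keeps that scalar in self.a."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Non-trainable variable holding the value of a drawn on the latest call.
        self.a = tf.Variable(0.0, trainable=False, dtype=tf.float32, name="a")

    def call(self, x):
        # A graph op, so a new value is drawn every time the forward pass runs.
        a = tf.random.uniform((), dtype=tf.float32)
        self.a.assign(a)
        return x + a


def my_loss(layer):
    def modified_loss(y_true, y_pred):
        # Reads the variable that the layer assigned earlier in the same step.
        y_true = tf.cast(y_true, tf.float32) + layer.a
        return tf.reduce_mean(tf.square(y_true - y_pred), axis=-1)
    return modified_loss


input_layer = tf.keras.Input(shape=(4,))  # input shape assumed for illustration
my_layer = MyLayer(name="my_layer")
output_layer = tf.keras.layers.Dense(4)(my_layer(input_layer))
model = tf.keras.Model(inputs=input_layer, outputs=output_layer)
model.compile("adam", loss=my_loss(my_layer))
```

Training then proceeds with model.fit(...) as usual. The design choice here is to keep a in TensorFlow state rather than in a plain Python attribute, because a Python value is only updated while the graph is being traced.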

Upvotes: 1
