zucchinifries

Reputation: 603

How to create a loss function parameter that is dependent on epoch number in Keras?

I have a custom loss function with a hyperparameter alpha that I want to change every 20 epochs over training. The loss function is something like:

def custom_loss(x, x_pred):
    loss1 = binary_crossentropy(x, x_pred)
    loss2 = some_other_loss(x, x_pred)  # placeholder for the second loss term
    return alpha * loss1 + (1 - alpha) * loss2

From my research, creating a custom callback is the way to go. I have looked at the solutions to similar questions here and here, but they do not use a callback, which is the approach I want to take.

I've attempted to create a custom callback by adapting the LearningRateScheduler callback from the Keras repo:

class changeAlpha(Callback):
    def __init__(self, alpha):
        super(changeAlpha, self).__init__()
        self.alpha = alpha

    def on_epoch_begin(self, epoch, logs={}):
        if epoch % 20 == 0:
            K.set_value(self.alpha, K.get_value(self.alpha) * epoch**0.95)
            print("Setting alpha to", K.get_value(self.alpha))

However, I am not certain that this alpha value actually corresponds to the alpha in my loss function. In any case, when I pass the changeAlpha callback to the model.fit method, I get an AttributeError.

Can someone help me edit the callback such that it alters my alpha parameter after a certain number of epochs?

Upvotes: 7

Views: 2436

Answers (1)

NormanZhu

Reputation: 121

I understand your idea. I think the problem is that the alpha in your loss function does not refer to the member of the changeAlpha class. You can try this:

instance = changeAlpha(alpha)
def custom_loss(x, x_pred):
    loss1 = binary_crossentropy(x, x_pred)
    loss2 = some_other_loss(x, x_pred)  # placeholder for the second loss term
    return instance.alpha * loss1 + (1 - instance.alpha) * loss2

Or, you can make alpha a class variable instead of an instance variable, and then change the loss function as below:

def custom_loss(x, x_pred):
    loss1 = binary_crossentropy(x, x_pred)
    loss2 = some_other_loss(x, x_pred)  # placeholder for the second loss term
    return changeAlpha.alpha * loss1 + (1 - changeAlpha.alpha) * loss2

Hope it can help you.
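The key point in both variants is that the loss function must read alpha through a shared, mutable reference rather than capturing a plain float at definition time. Here is a minimal framework-free sketch of that mechanism (the names AlphaHolder, make_loss, and on_epoch_begin are illustrative, not Keras API; in real Keras code the holder's role is played by a backend variable updated via K.set_value):

```python
class AlphaHolder:
    """Mutable container standing in for a Keras backend variable."""
    def __init__(self, value):
        self.value = value

def make_loss(holder):
    # The closure reads holder.value at call time, not at definition time,
    # so per-epoch updates made by the callback are seen by the loss.
    def custom_loss(loss1, loss2):
        return holder.value * loss1 + (1 - holder.value) * loss2
    return custom_loss

def on_epoch_begin(holder, epoch, factor=0.5):
    # Halve alpha every 20 epochs (a simple stand-in schedule).
    if epoch > 0 and epoch % 20 == 0:
        holder.value *= factor

alpha = AlphaHolder(0.8)
loss = make_loss(alpha)
print(loss(1.0, 0.0))      # 0.8 before any update
on_epoch_begin(alpha, 20)  # epoch 20 triggers the schedule
print(loss(1.0, 0.0))      # 0.4 after the update
```

If alpha were captured as a bare float inside custom_loss, the callback's updates would never reach the loss; sharing a mutable holder (or a backend variable) is what makes the schedule take effect.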

Upvotes: 4
