Baunza

Reputation: 103

keras: how to use learning rate decay with model.train_on_batch()

In my current project I'm using Keras' train_on_batch() function for training, since fit() does not support the alternating training of generator and discriminator required for GANs. With (for example) the Adam optimizer I have to specify the learning rate decay in the constructor via optimizer = Adam(decay=my_decay) and hand this to the model's compile() method. This works fine if I then use the model's fit() function, since fit() counts the training iterations internally, but I don't know how I can set this value myself when using a construct like

counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        counter += 1  # advance the step counter each batch
        # train model:
        loss = model.train_on_batch(x, y)    # how to use the current learning rate?

with some function to calculate the learning rate. How can I set the current learning rate manually?
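For reference, this is roughly how I set up the optimizer at the moment (the decay value and the loss function are just placeholders):

from keras.optimizers import Adam

my_decay = 1e-4  # placeholder decay value
optimizer = Adam(lr=0.001, decay=my_decay)
model.compile(optimizer=optimizer, loss='binary_crossentropy')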

If there are any mistakes in this post, I'm sorry; it's my first question here.

Thank you in advance for any help.

Upvotes: 10

Views: 4065

Answers (1)

Mikhail Stepanov

Reputation: 3790

EDIT

In Keras 2.3.0, lr was renamed to learning_rate. In older versions you should use lr instead (thanks @Bananach).
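If your code has to run on both older and newer versions, a small compatibility sketch (assuming model is already compiled) can pick whichever attribute the installed optimizer actually exposes:

# compatibility sketch: use `learning_rate` where it exists
# (Keras >= 2.3.0), otherwise fall back to the old `lr` attribute
lr_variable = getattr(model.optimizer, 'learning_rate', None)
if lr_variable is None:
    lr_variable = model.optimizer.lr  # pre-2.3.0 name

You can then pass lr_variable to K.set_value below.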

Set the value with the help of the Keras backend: keras.backend.set_value(model.optimizer.learning_rate, learning_rate) (where learning_rate is a float, the desired learning rate). This works for the fit method and should work for train_on_batch as well:

from keras import backend as K


counter = 0
for epoch in range(EPOCHS):
    for batch_idx in range(0, number_training_samples, BATCH_SIZE):
        # get training batch:
        x = ...
        y = ...
        # calculate learning rate:
        current_learning_rate = calculate_learning_rate(counter)
        # set the new learning rate before the update step:
        K.set_value(model.optimizer.learning_rate, current_learning_rate)
        # train model:
        loss = model.train_on_batch(x, y)
        counter += 1  # advance the step counter so the schedule progresses
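Here calculate_learning_rate can be any schedule you like; as a minimal sketch, this one reproduces the inverse-time decay that Keras applies internally for the optimizer's decay argument (initial_lr and decay are placeholder values):

def calculate_learning_rate(iteration, initial_lr=0.001, decay=1e-4):
    # inverse-time decay, the same formula Keras uses for the
    # built-in `decay` argument: lr = lr0 / (1 + decay * t)
    return initial_lr / (1.0 + decay * iteration)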

Hope it helps!

Upvotes: 18
