fwalch

Reputation: 1146

TensorFlow: How to set learning rate decay based on epochs?

The learning rate decay function tf.train.exponential_decay takes a decay_steps parameter. To decrease the learning rate every num_epochs epochs, you would set decay_steps = num_epochs * num_train_examples / batch_size. However, when reading data from .tfrecords files, you don't know how many training examples they contain.
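For reference, this is roughly how the epoch-based decay_steps would be wired up if num_train_examples were known (all concrete values below are placeholders):

import tensorflow as tf

# Placeholder values; num_train_examples is exactly what's unknown here.
num_train_examples = 50000
batch_size = 128
num_epochs = 2  # decay once every 2 epochs

global_step = tf.train.get_or_create_global_step()
decay_steps = num_epochs * num_train_examples // batch_size
learning_rate = tf.train.exponential_decay(0.1, global_step,
                                           decay_steps, 0.96, staircase=True)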

To get num_train_examples, you could iterate over the whole file once and count the records. However, this isn't very elegant.

Is there an easier way to either get the number of training examples from a .tfrecords file or set the learning rate decay based on epochs instead of steps?

Upvotes: 1

Views: 5026

Answers (3)

Lerner Zhang

Reputation: 7130

I recommend setting the learning rate decay according to changes in the training or evaluation loss: if the loss is oscillating, you can decrease the learning rate. You can hardly predict, before training starts, from which epoch or step you should decrease it.
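For instance, with Keras-style training this idea can be expressed with ReduceLROnPlateau (the factor/patience values below are only illustrative):

import tensorflow as tf

# Halve the learning rate when the validation loss stops improving
# for 3 epochs; the exact factor/patience values are illustrative.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss',
                                                 factor=0.5,
                                                 patience=3,
                                                 min_lr=1e-5)

# model.fit(train_x, train_y, validation_data=(val_x, val_y),
#           epochs=50, callbacks=[reduce_lr])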

Upvotes: 0

Aakash Saxena

Reputation: 309

In the learning_rate definition below,

learning_rate = tf.train.exponential_decay(starter_learning_rate, global_step,
                                           100000, 0.96, staircase=True)

starter_learning_rate can be changed after the desired number of epochs by defining a function like:

def initial_learning_rate(epoch):
    if (epoch >= 0) and (epoch < 100):
        return 0.1
    if (epoch >= 100) and (epoch < 200):
        return 0.05
    if (epoch >= 200) and (epoch < 500):
        return 0.001

Then you can set starter_learning_rate inside the for loop (iterating over epochs) as:

for epoch in range(epochs):  # epochs is the total number of epochs
    starter_learning_rate = initial_learning_rate(epoch)
    ...

Note

Changing starter_learning_rate does not reset global_step; the decay is still applied on top of it, as in:

decayed_learning_rate = starter_learning_rate *
                        decay_rate ^ (global_step / decay_steps)

Upvotes: 0

keveman

Reputation: 8487

You can use the following code to get the number of records in a .tfrecords file:

import tensorflow as tf

def get_num_records(tf_record_file):
  # Count the records by iterating over the file once.
  return sum(1 for _ in tf.python_io.tf_record_iterator(tf_record_file))
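Tying this back to the decay_steps formula from the question (file name and hyperparameter values below are placeholders):

num_epochs, batch_size = 2, 128                           # placeholder values
num_train_examples = get_num_records('train.tfrecords')   # placeholder path
decay_steps = num_epochs * num_train_examples // batch_size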

Upvotes: 3
