chesschi

Reputation: 708

Tensorflow: Is it possible to modify the global step in checkpoints?

I am trying to modify the global_step in checkpoints so that I can move the training from one machine to another.

Let's say I was training for several days on machine A. Now I have bought a new machine B with a better graphics card and more GPU memory, and I would like to move the training from machine A to machine B.

In order to restore the checkpoints on machine B, I had previously specified the global_step in Saver.save on machine A, but with a smaller batch_size and a larger sub_iterations.

batch_size = 10
sub_iterations = 500
for _ in range(iterations):
    for i in range(sub_iterations):
        batch_inputs, batch_labels = next_batch(batch_size)
        session.run(optimizer, feed_dict={inputs: batch_inputs, labels: batch_labels})

saver = tf.train.Saver()
saver.save(session, checkpoints_path, global_step)

Now I have copied all the files, including the checkpoints, from machine A to machine B. Because machine B has more GPU memory, I can increase batch_size but use fewer sub_iterations.

batch_size = 100
sub_iterations = 50  # = 500 / (100/10)
for _ in range(iterations):
    for i in range(sub_iterations):
        batch_inputs, batch_labels = next_batch(batch_size)
        session.run(optimizer, feed_dict={inputs: batch_inputs, labels: batch_labels})

However, we cannot directly restore the copied checkpoints, because global_step no longer matches machine B's schedule. For example, tf.train.exponential_decay will produce an incorrect learning_rate since sub_iterations is reduced on machine B.

learning_rate = tf.train.exponential_decay(..., global_step, sub_iterations, decay_rate, staircase=True)
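To see the mismatch concretely, here is the staircase schedule computed in plain Python (a sketch; the base learning rate 0.1 and decay_rate 0.96 are illustrative values, not from my actual setup):

```python
def staircase_lr(base_lr, decay_rate, gstep, decay_steps):
    # tf.train.exponential_decay with staircase=True floors the exponent.
    return base_lr * decay_rate ** (gstep // decay_steps)

gstep = 501  # restored from machine A's checkpoint

lr_a = staircase_lr(0.1, 0.96, gstep, decay_steps=500)  # machine A schedule
lr_b = staircase_lr(0.1, 0.96, gstep, decay_steps=50)   # machine B schedule
print(lr_a, lr_b)  # on B the rate has decayed ten times instead of once
```

With the same global_step of 501, machine A has applied one decay step (501 // 500 = 1) while machine B has applied ten (501 // 50 = 10), so the restored learning rate is far too small.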

Is it possible to modify the global_step in the checkpoints? Or is there an alternative, more appropriate way to handle this situation?

Edit 1

In addition to calculating the learning_rate, I also use the global_step to recover the loop indices i and j, so that only the remaining iterations are run.

while i < iterations:
    j = 0

    while j < sub_iterations:
        batch_inputs, batch_labels = next_batch(batch_size)
        feed_dict_train = {inputs: batch_inputs, labels: batch_labels}
        _, gstep = session.run([optimizer, global_step], feed_dict=feed_dict_train)
        if (i == 0) and (j == 0):
            # On the very first step, jump to the indices implied by the restored global_step
            i, j = int(gstep / sub_iterations), numpy.mod(gstep, sub_iterations)
        j = j + 1
    i = i + 1

We then resume the iterations from the new i and j. Please feel free to comment on this, as it might not be a good approach for restoring checkpoints and continuing training from them.
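The resume arithmetic above can be isolated as a small pure-Python helper (a sketch; gstep and sub_iterations stand in for the values in the training loop):

```python
def resume_indices(gstep, sub_iterations):
    # Map a restored global_step back to the outer index i and inner index j,
    # mirroring i, j = int(gstep / sub_iterations), numpy.mod(gstep, sub_iterations).
    return gstep // sub_iterations, gstep % sub_iterations

# Machine A's schedule (sub_iterations=500) with global_step=501:
print(resume_indices(501, 500))  # (1, 1)
```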

Edit 2

On machine A, let's say iterations is 10,000, sub_iterations is 500 and batch_size is 10, so the total number of samples we aim to train on is 10,000 x 500 x 10 = 50,000,000. Assume we have trained for several days and global_step has reached 501, so the number of samples trained so far is 501 x 10 = 5,010; samples 5,011 to 50,000,000 remain untrained. Applying i, j = int(gstep / sub_iterations), numpy.mod(gstep, sub_iterations), the last trained value of i is 501 // 500 = 1 and j is 501 % 500 = 1.

Now we copy all the files, including the checkpoints, to machine B. Since B has more GPU memory and can process more samples per sub-iteration, we set batch_size to 100 and adjust sub_iterations to 50, but leave iterations at 10,000. The total number of samples to train on is still 50,000,000. So the question is: how can we resume training from sample 5,011 onwards without retraining on the first 5,010 samples?

In order to resume from sample 5,011 on machine B, we should set i to 1 and j to 0, because the number of samples already covered would then be (1 x 50 + 0) x 100 = 5,000, which is close to 5,010 (since batch_size is 100 on machine B rather than 10 on machine A, we cannot resume at exactly 5,010; we can only choose 5,000 or 5,100).

If we do not adjust the global_step (as suggested by @coder3101) and reuse i, j = int(gstep / sub_iterations), numpy.mod(gstep, sub_iterations) on machine B, i becomes 501 // 50 = 10 and j becomes 501 % 50 = 1. So we would resume from sample 50,100 (= 501 x batch_size = 501 x 100), which is incorrect (nowhere near 5,010).

The formula i, j = int(gstep / sub_iterations), numpy.mod(gstep, sub_iterations) was introduced so that if we stop training on machine A at any point, we can restore the checkpoints and continue training on machine A. However, the formula no longer holds when moving the training from machine A to machine B. Therefore I was hoping to modify the global_step in the checkpoints to deal with this situation, and I would like to know whether that is possible.

Upvotes: 2

Views: 4593

Answers (2)

chesschi

Reputation: 708

Yes. It is possible.

To modify the global_step in machine B, you have to perform the following steps:

Calculate the corresponding global_step

In the above example, global_step on machine A is 501 and the total number of trained samples is 501 x 10 = 5,010. So the corresponding global_step on machine B is 5,010 / 100 ≈ 50.
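The conversion is just a rescaling that preserves the number of trained samples (a sketch; integer division is used because global_step must be an integer, so the result may be approximate):

```python
def convert_global_step(gstep_a, batch_a, batch_b):
    # Keep the number of trained samples constant across machines:
    # gstep_b * batch_b should approximate gstep_a * batch_a.
    return (gstep_a * batch_a) // batch_b

print(convert_global_step(501, 10, 100))  # 5,010 samples -> step 50 on machine B
```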

Modify the checkpoint filenames

Modify the suffix of the model_checkpoint_path and all_model_checkpoint_paths values in the checkpoint file so that they use the correct step number (e.g. 50 in the above example).

Rename the .index, .meta and .data files of the checkpoints so that their filenames use the correct step number.
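A sketch of the renaming step, assuming TF1-style checkpoint files named model.ckpt-&lt;step&gt;.* (checkpoint_dir and the prefix are placeholders for your actual paths):

```python
import glob
import os

def rename_checkpoint_files(checkpoint_dir, old_step, new_step, prefix="model.ckpt"):
    # Rename model.ckpt-501.index / .meta / .data-00000-of-00001 (etc.)
    # to model.ckpt-50.* so the filenames match the converted global_step.
    old_base = os.path.join(checkpoint_dir, "%s-%d" % (prefix, old_step))
    for path in glob.glob(old_base + ".*"):
        suffix = path[len(old_base):]  # e.g. ".index", ".meta", ".data-00000-of-00001"
        new_path = os.path.join(checkpoint_dir, "%s-%d%s" % (prefix, new_step, suffix))
        os.rename(path, new_path)
```

The checkpoint bookkeeping file itself (the one holding model_checkpoint_path) still has to be edited by hand as described above.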

Override global_step after restoring checkpoint

saver.restore(session, model_checkpoint_path)
# Overwrite the restored global_step with the converted value
initial_global_step = global_step.assign(50)
session.run(initial_global_step)
...
# do the training

Now if you restore the checkpoints (including global_step) and then override global_step, training will use the updated global_step and compute i, j and the learning_rate correctly.

Upvotes: 6

coder3101

Reputation: 4165

If you want the decayed learning rate on machine B to stay the same as it was on machine A, you can tweak the other parameters of the function tf.train.exponential_decay. For example, you can change the decay_rate on the new machine.

In order to find that decay_rate, you need to know how exponential_decay is computed:

decayed_learning_rate = learning_rate * decay_rate ^ (global_step / decay_steps)

global_step / decay_steps is floored to an integer when staircase=True.

Here decay_steps is your sub_iterations and global_step comes from machine A.

You can change the decay_rate so that the new learning rate on machine B is the same as, or close to, what you would expect on machine A at that global_step.

Alternatively, you can change the initial learning rate on machine B so as to achieve a uniform exponential decay from machine A to B.

So in short, while moving from A to B you have changed one variable (sub_iterations) and kept global_step the same. You can adjust the other two variables of exponential_decay(...) so that the learning rate output by the function is the same as what you would expect on machine A at that global_step.
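One way to make this concrete (a sketch, assuming staircase=True and an illustrative decay_rate of 0.96): solve decay_rate_b ** (gstep // decay_steps_b) == decay_rate_a ** (gstep // decay_steps_a) for decay_rate_b.

```python
def matched_decay_rate(decay_rate_a, gstep, decay_steps_a, decay_steps_b):
    # Choose decay_rate_b so that the staircase-decayed learning rate on
    # machine B equals machine A's at this particular (unmodified) global_step:
    # r_b ** (gstep // d_b) == r_a ** (gstep // d_a)
    exp_a = gstep // decay_steps_a
    exp_b = gstep // decay_steps_b
    return decay_rate_a ** (float(exp_a) / exp_b)

# gstep=501, sub_iterations 500 on A and 50 on B:
r_b = matched_decay_rate(0.96, 501, 500, 50)
# 0.96 ** (1/10) ~= 0.996; after 10 staircase steps this reproduces 0.96 ** 1
```

Note that this matches the two schedules only at that particular global_step; as training continues the two exponents advance at different rates, so the match drifts over time.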

Upvotes: 1
