Reputation: 1
I am trying to use mixed-precision training with tf-slim in order to speed up the training of networks and make use of the Tensor Cores available on my GPUs. I also want to use multiple network architectures with pre-trained checkpoints.
An example of what Mixed-Precision training is and how it works can be found at https://devblogs.nvidia.com/mixed-precision-resnet-50-tensor-cores/
The basic idea is:

1. Cast the inputs and weights to fp16 for the forward and backward pass.
2. Cast the values back to fp32 when adjusting the loss and weights.
3. Multiply the loss by a loss scale before the backward pass.
4. Divide the gradients by the same loss scale before updating the weights.
Computing in fp16 reduces the memory bandwidth needed and makes use of the Tensor Cores on Volta and Turing GPUs.
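As I understand it, those four steps would look roughly like this in plain TF 1.x. This is only a sketch: model_fn is a stand-in for the slim network function, and the fixed loss scale is chosen just for illustration.

import tensorflow as tf

loss_scale = 128.0  # fixed scaling factor, for illustration only
inputs = tf.placeholder(tf.float32, [None, 224, 224, 3])
labels = tf.placeholder(tf.int64, [None])

# 1. Cast the inputs to fp16 so the forward/backward pass runs in fp16.
logits_fp16 = model_fn(tf.cast(inputs, tf.float16))  # hypothetical network fn

# 2. Cast back to fp32 before computing the loss.
logits = tf.cast(logits_fp16, tf.float32)
loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)

# 3. Multiply the loss by the loss scale before the backward pass.
opt = tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9)
grads_and_vars = opt.compute_gradients(loss * loss_scale)

# 4. Divide the gradients by the same loss scale before the weight update.
unscaled = [(grad / loss_scale, var)
            for grad, var in grads_and_vars if grad is not None]
train_op = opt.apply_gradients(unscaled)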
My problem is that I can't figure out where to put the casts to fp16 and fp32 with tf-slim.
To start the training, I use the train_image_classifier.py script from models/research/slim.
Do I need to do the cast within the definition files for the network architectures? Or do I need to apply the changes within the tf.contrib.slim files?
Upvotes: 0
Views: 753
Reputation: 361
NVIDIA's documentation on mixed precision training gives a clear example of how to do this with TensorFlow.
TensorFlow has implemented the loss scaling in tf.contrib.mixed_precision.LossScaleOptimizer. From what I understand, it uses the same strategy as described in that documentation:
import tensorflow as tf

loss = loss_fn()  # the usual fp32 training loss
opt = tf.train.AdamOptimizer(learning_rate=...)

# Choose a loss scale manager, which decides how to pick the right loss
# scale throughout the training process. Either use a fixed scaling factor:
loss_scale_manager = tf.contrib.mixed_precision.FixedLossScaleManager(loss_scale)
# ...or a dynamic one, which raises the scale every incr_every_n_steps steps
# without overflow and lowers it when inf/nan gradients appear:
loss_scale_manager = tf.contrib.mixed_precision.ExponentialUpdateLossScaleManager(
    init_loss_scale, incr_every_n_steps)

# Wrap the original optimizer in a LossScaleOptimizer.
loss_scale_optimizer = tf.contrib.mixed_precision.LossScaleOptimizer(
    opt, loss_scale_manager)

# minimize() on the loss scale optimizer scales the loss before computing
# gradients and unscales the gradients before applying the update.
train_op = loss_scale_optimizer.minimize(loss)
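For the other half of the question, where to put the casts: they do not have to go into the slim definition files. NVIDIA's documentation does it with a custom variable getter, so the variables are stored as fp32 master weights but handed to the graph as fp16. A sketch along those lines, with network_fn standing in for the slim model function:

def float32_variable_storage_getter(getter, name, shape=None, dtype=None,
                                    initializer=None, regularizer=None,
                                    trainable=True, *args, **kwargs):
    # Store trainable variables in fp32 and cast them to the requested
    # dtype (fp16) on the way out, giving fp32 master weights.
    storage_dtype = tf.float32 if trainable else dtype
    variable = getter(name, shape, dtype=storage_dtype,
                      initializer=initializer, regularizer=regularizer,
                      trainable=trainable, *args, **kwargs)
    if trainable and dtype != tf.float32:
        variable = tf.cast(variable, dtype)
    return variable

with tf.variable_scope('model',
                       custom_getter=float32_variable_storage_getter):
    # network_fn is a placeholder for the slim model function; only the
    # input cast is needed here, the getter handles the weights.
    logits = network_fn(tf.cast(images, tf.float16))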
Upvotes: 1