Caesar

Reputation: 1092

How can I use Apex AMP (Automatic Mixed Precision) with model parallelism in PyTorch?

My model has a few LSTMs that run out of CUDA memory on large sequences with a single GPU, so I moved a few components of the model to another GPU. I tried two things with Apex AMP:

  1. Move the model components to their GPUs before invoking amp.initialize (roughly as in the sketch after this list). In this case, I get NaNs soon after the first backpropagation.

  2. First invoke amp.initialize, and then move the model components to the other GPU. In this case, it behaves as if backpropagation runs on a single GPU, and it runs out of CUDA memory.
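
For reference, this is roughly how I set things up for attempt 1. The layer sizes, device ids, and module names below are placeholders, not my actual model:

```python
import torch
import torch.nn as nn
from apex import amp

class TwoGPUModel(nn.Module):
    # Placeholder model: one LSTM per GPU, linear head on the second GPU.
    def __init__(self, input_size=64, hidden_size=128):
        super().__init__()
        self.lstm1 = nn.LSTM(input_size, hidden_size, batch_first=True).to("cuda:0")
        self.lstm2 = nn.LSTM(hidden_size, hidden_size, batch_first=True).to("cuda:1")
        self.head = nn.Linear(hidden_size, 1).to("cuda:1")

    def forward(self, x):
        out, _ = self.lstm1(x.to("cuda:0"))
        out, _ = self.lstm2(out.to("cuda:1"))   # hand activations over to the second GPU
        return self.head(out)

model = TwoGPUModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Attempt 1: components are already on their GPUs when amp.initialize is called.
model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
```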

Training runs fine without Apex, so I suppose I am missing some step where the loss is backpropagated across both GPUs (my training step is sketched below). I looked through the Apex documentation, but it only covers data parallelism, not model parallelism.
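
This is roughly what my training step looks like, continuing the sketch above (the loss function and tensor shapes are placeholders):

```python
criterion = nn.MSELoss()
x = torch.randn(8, 1000, 64)               # (batch, seq_len, features), starts on the CPU
y = torch.randn(8, 1000, 1).to("cuda:1")   # targets on the GPU that holds the output

optimizer.zero_grad()
loss = criterion(model(x), y)
with amp.scale_loss(loss, optimizer) as scaled_loss:
    scaled_loss.backward()                  # expecting autograd to cross the GPU boundary here
optimizer.step()
```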

Any ideas?

Upvotes: 2

Views: 275

Answers (0)
