Reputation: 11132
I have a complex model that I would like to train in mixed precision. To do this, I use the torch.amp package. I can enable AMP for the whole model using `with torch.cuda.amp.autocast(enabled=enable_amp, dtype=torch.float16):`. However, the model training is not stable, so I would like to force certain parts of the model to run in float32.
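For context, my training loop looks roughly like this (the model, criterion, optimizer, and loader names are just placeholders):

```python
import torch

enable_amp = True
scaler = torch.cuda.amp.GradScaler(enabled=enable_amp)

for inputs, targets in loader:  # placeholder data loader
    inputs, targets = inputs.cuda(), targets.cuda()
    optimizer.zero_grad()

    # AMP covers the entire forward pass and loss computation
    with torch.cuda.amp.autocast(enabled=enable_amp, dtype=torch.float16):
        outputs = model(inputs)
        loss = criterion(outputs, targets)

    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```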
Here's what I've tried or considered:
There are two officially endorsed solutions I'm aware of: disable AMP for a block and cast all input tensors at the start of the block, or use `custom_fwd` as described in this answer. However, both have issues. The first requires manually casting each input tensor to float32. The second requires adding the `custom_fwd` decorator to a forward function, so I would need to either add it to each module individually or write a new container module that holds the other modules. Neither of those works well for me: I want to test enabling and disabling float16 for many different parts of my model, so I would be constantly adding and removing code to cast dozens of tensors and/or modules.
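Concretely, the two workarounds look something like this (run_in_fp32 and ForceFP32 are placeholder names of mine, not from any particular codebase):

```python
import torch
import torch.nn as nn

# Option 1: disable autocast for a region and cast every input by hand.
def run_in_fp32(block, x, y):
    with torch.cuda.amp.autocast(enabled=False):
        return block(x.float(), y.float())  # each tensor cast manually

# Option 2: wrap a module so custom_fwd casts its floating-point inputs
# to float32 and runs the wrapped forward with autocast disabled.
class ForceFP32(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)
```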
What I want is the ability to cast to float32 for an entire block of code, something like `with [cast everything to float32]:`, but I don't know a way to do that reliably. `with torch.cuda.amp.autocast(enabled=False):` doesn't cast float16 tensors to float32; it only stops float32 tensors from being cast to float16. `with torch.cuda.amp.autocast(dtype=torch.float32):` appears to work, but that isn't officially documented usage, and based on the docs I'm not confident it will keep working in the future. This model will continue to be used and updated for years by a team of people, so I don't want to risk it breaking if a PyTorch update changes undocumented behavior.
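To make the problem concrete, here is a sketch of the two block-level attempts inside the outer autocast region (model_part and sub_block are hypothetical pieces of my model):

```python
with torch.cuda.amp.autocast(enabled=True, dtype=torch.float16):
    h = model_part(x)  # h may already be float16 at this point

    # Attempt 1: this only stops further down-casting; h is still float16,
    # so I'd have to cast it to float32 by hand before calling sub_block.
    with torch.cuda.amp.autocast(enabled=False):
        out = sub_block(h)

    # Attempt 2: appears to make autocast-eligible ops inside sub_block
    # run with float32 inputs, but this usage isn't clearly documented.
    with torch.cuda.amp.autocast(dtype=torch.float32):
        out = sub_block(h)
```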
Does anyone know of a way to reliably cast everything in a whole block of code to float32?
Upvotes: 1
Views: 1071