Reputation: 4630
I'm trying to save memory while training a model that uses single precision weights by doing the calculations in half precision.
I tried using autocast, and the model does its predictions in half precision as it should. However, the gradients produced are still in single precision, which defeats both the performance gain and the memory savings. Is there any way to instruct torch to compute the grads in half precision and use them to update the original single-precision weights?
import torch

class KekNet(torch.nn.Module):
    def __init__(self):
        super(KekNet, self).__init__()
        self.layer1 = torch.nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), dtype=torch.float32)

    def forward(self, x, features=False):
        return self.layer1(x)

device = torch.device("cuda")

# HALF-DATA AUTOCAST
net = KekNet().to(device)
loss_l2 = torch.nn.MSELoss(reduction='none')
g_params = [{'params': net.parameters(), 'weight_decay': 0}]
optimizerG = torch.optim.RMSprop(g_params, lr=3e-5, alpha=0.99, eps=1e-07, weight_decay=0)
schedulerG = torch.optim.lr_scheduler.CosineAnnealingLR(optimizerG, T_max=300)

X = torch.randn((40, 3, 555, 555), dtype=torch.float16, device=device)

with torch.autocast(device_type='cuda', dtype=torch.float16):
    Y_h = net(X)
    Y = torch.randn_like(Y_h)
    loss = loss_l2(Y_h, Y).mean()

loss.backward()

print(f"-autocast\r\ndata precision: {X.dtype}\r\npred precision: {Y_h.dtype}\r\ngrad precision: {net.layer1.weight.grad.dtype}\r\n")

optimizerG.step()
schedulerG.step()
This results in the following output:
data precision: torch.float16
pred precision: torch.float16
grad precision: torch.float32
Upvotes: -1
Views: 1055
Reputation: 5363
Autocast doesn't transform the weights of the model, so weight grads will have the same dtype as the weights. You can try manually calling .half()
on the model to change this. I'm not sure if there's a way to compute grads in fp16 while keeping the weights in fp32.
import torch
import torch.nn as nn

torch.set_default_device('cuda')

model = nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=1e-3)

x = torch.randn(12, 8, dtype=torch.float16)

# weights are fp32, so the weight grads come out in fp32
# even though the forward pass runs in fp16
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
    output = model(x)
    loss = output.mean()

loss.backward()
print(model.weight.grad.dtype)
# > torch.float32

opt.zero_grad()

# cast the weights themselves to fp16
model.half()

with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
    output = model(x)
    loss = output.mean()

loss.backward()
print(model.weight.grad.dtype)
# > torch.float16
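
If you specifically want fp16 grads while the values the optimizer updates stay in fp32, one thing you could try (just a rough sketch of the usual "master weights" pattern, not something autocast does for you) is to run the forward/backward on an fp16 model and keep separate fp32 master copies of the parameters for the optimizer step:

import torch
import torch.nn as nn

torch.set_default_device('cuda')

# fp16 working model: forward, backward and the grads all stay in fp16
model = nn.Linear(8, 1).half()

# fp32 master copies of the parameters; the optimizer only ever sees these
master_params = [p.detach().clone().float().requires_grad_(True) for p in model.parameters()]
opt = torch.optim.SGD(master_params, lr=1e-3)

x = torch.randn(12, 8, dtype=torch.float16)

loss = model(x).mean()
loss.backward()
print(model.weight.grad.dtype)
# > torch.float16

with torch.no_grad():
    # copy the fp16 grads onto the fp32 masters, step, then copy the
    # updated fp32 values back into the fp16 working weights
    for p, mp in zip(model.parameters(), master_params):
        mp.grad = p.grad.float()
    opt.step()
    for p, mp in zip(model.parameters(), master_params):
        p.copy_(mp)

With pure fp16 grads you'd normally also want some form of loss scaling (that's what torch.cuda.amp.GradScaler is built around) to avoid gradient underflow.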
Additionally, some ops have numerical stability issues when computed in fp16. To avoid this, PyTorch autocasts certain ops to fp32; you can find the full list in the autocast op reference in the PyTorch docs. In your case, MSE loss (and really the pow function) is autocast to fp32. This won't change the weight grad dtype in the example above, but it's worth noting if you see fp32 cropping up in other places.
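As a quick check (assuming a CUDA device is available), both mse_loss and pow return fp32 results under autocast even when all inputs are fp16:

import torch
import torch.nn.functional as F

torch.set_default_device('cuda')

a = torch.randn(4, dtype=torch.float16)
b = torch.randn(4, dtype=torch.float16)

with torch.autocast(device_type='cuda', dtype=torch.float16):
    # both ops are on autocast's "runs in float32" list
    print(F.mse_loss(a, b).dtype)  # torch.float32
    print(a.pow(2).dtype)          # torch.float32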
Upvotes: 2