Reputation: 61
Suppose we have a function f whose gradient is slow to compute, and two functions g1 and g2 whose gradients are easy to compute. In PyTorch, how can I calculate the gradients of z1 = g1(f(x)) and z2 = g2(f(x)) with respect to x, without having to calculate the gradient of f twice?
Example:
import torch
import time
def slow_fun(x):
    A = x*torch.ones((1000,1000))
    B = torch.matrix_exp(1j*A)
    return torch.real(torch.trace(B))
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z1 = y**2
z2 = torch.sqrt(y)
start = time.time()
z1.backward(retain_graph = True)
end = time.time()
print("dz1/dx: ", x.grad)
print("duration: ", end-start, "\n")
x.grad = None
start = time.time()
z2.backward(retain_graph = True)
end = time.time()
print("dz2/dx: ", x.grad)
print("duration: ", end-start, "\n")
This prints
dz1/dx: tensor(-1673697.1250)
duration: 1.5571658611297607
dz2/dx: tensor(-13.2334)
duration: 1.3989012241363525
so calculating dz2/dx takes about as long as calculating dz1/dx.
The calculation of dz2/dx could be sped up if PyTorch stored dy/dx during the calculation of dz1/dx and then reused that result during the calculation of dz2/dx.
Is there a mechanism built into PyTorch to achieve such behavior?
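For illustration, the kind of reuse I have in mind looks roughly like this manual sketch (using torch.autograd.grad to obtain dy/dx once; the local derivatives 2*y and 0.5/sqrt(y) are simply written out by hand here):
# sketch: backprop through slow_fun only once, then apply the chain rule by hand
dy_dx, = torch.autograd.grad(y, x)      # expensive step, done once
dz1_dy = 2*y.detach()                   # d(y**2)/dy, cheap
dz2_dy = 0.5/torch.sqrt(y.detach())     # d(sqrt(y))/dy, cheap
dz1_dx = dz1_dy * dy_dx                 # chain rule
dz2_dx = dz2_dy * dy_dx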
Upvotes: 4
Views: 201
Reputation: 40768
To complement @Karl's answer, here is an alternative solution that doesn't require detaching tensors. It uses torch.autograd.functional.jacobian to compute and extract the gradient. Here again, we compute dy/dx once and reuse it via the chain rule.
from torch.autograd.functional import jacobian

x = torch.tensor(1.0)
y = slow_fun(x)  # forward pass of the slow function
# compute dy/dx once
dy_dx = jacobian(slow_fun, x)
# compute dz1/dx via the chain rule
dz1_dy = jacobian(lambda y: y**2, y)
dz1_dx = dy_dx*dz1_dy
# compute dz2/dx via the chain rule
dz2_dy = jacobian(lambda y: y.sqrt(), y)
dz2_dx = dy_dx*dz2_dy
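Note that with scalar x and y, as in the question, dy_dx and dz1_dy are scalar tensors, so a plain product implements the chain rule. For vector-valued inputs the same idea carries over with a matrix product; here is a small toy sketch (f and g1 below are just stand-ins, not the question's slow_fun):
import torch
from torch.autograd.functional import jacobian

f = lambda x: torch.stack([x.sum()**2, x.prod()])  # toy stand-in for the slow function
g1 = lambda y: (y**2).sum()                        # toy cheap head

x = torch.tensor([1.0, 2.0, 3.0])
y = f(x)

dy_dx = jacobian(f, x)    # shape (2, 3)
dz1_dy = jacobian(g1, y)  # shape (2,)
dz1_dx = dz1_dy @ dy_dx   # chain rule, shape (3,)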
Upvotes: 0
Reputation: 5473
You can decouple the nested function using the chain rule. However, there will be some small differences in the results due to numerical issues.
import torch
import time
def slow_fun(x):
    A = x*torch.ones((1000,1000))
    B = torch.matrix_exp(1j*A)
    return torch.real(torch.trace(B))
# baseline z1
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z1 = y**2
z1.backward()
dz1_dx = x.grad
print(dz1_dx)
> tensor(-1648274.2500)
# baseline z2
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z2 = torch.sqrt(y)
z2.backward()
dz2_dx = x.grad
print(dz2_dx)
> tensor(-13.1979)
# compute just dy/dx
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
y.backward()
dy_dx = x.grad
# detach y to prevent full backprop
y1 = y.detach()
y1.requires_grad = True
z1 = y1**2
z1.backward()
dz1_dy = y1.grad
# compute gradient with chain rule
dz1_dx = dz1_dy * dy_dx
print(dz1_dx)
> tensor(-1672148.5000)
# detach y to prevent full backprop
y2 = y.detach()
y2.requires_grad = True
z2 = torch.sqrt(y2)
z2.backward()
dz2_dy = y2.grad
# compute gradient with chain rule
dz2_dx = dz2_dy * dy_dx
print(dz2_dx)
> tensor(-13.1980)
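The same pattern can be wrapped in a small helper. Here is a sketch that builds on the code above (the name grad_through_shared and its signature are just illustrative, and it assumes a scalar-valued slow_fn as in the question):
def grad_through_shared(x, slow_fn, heads):
    # backprop through the expensive function only once
    y = slow_fn(x)
    dy_dx = torch.autograd.grad(y, x)[0]
    grads = []
    for head in heads:
        # cheap backward through the head only, on a detached copy of y
        y_local = y.detach().requires_grad_(True)
        head(y_local).backward()
        grads.append(y_local.grad * dy_dx)  # chain rule
    return grads

dz1_dx, dz2_dx = grad_through_shared(
    torch.tensor(1.0, requires_grad=True),
    slow_fun,
    [lambda y: y**2, torch.sqrt],
)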
Upvotes: 0