klpskp

Reputation: 61

Calculating two gradients in PyTorch and reusing an intermediate gradient

Suppose we have a function f whose gradient is slow to compute, and two functions g1 and g2 whose gradients are easy to compute. In PyTorch, how can I calculate the gradients of z1 = g1(f(x)) and z2 = g2(f(x)) with respect to x without having to calculate the gradient of f twice?

Example:

import torch
import time

def slow_fun(x):
    A = x*torch.ones((1000,1000))
    B = torch.matrix_exp(1j*A)
    return torch.real(torch.trace(B))

x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z1 = y**2
z2 = torch.sqrt(y)

start = time.time()
z1.backward(retain_graph = True)
end = time.time()
print("dz1/dx: ", x.grad)
print("duration: ", end-start, "\n")

x.grad = None
start = time.time()
z2.backward(retain_graph = True)
end = time.time()
print("dz2/dx: ", x.grad)
print("duration: ", end-start, "\n")

This prints

dz1/dx:  tensor(-1673697.1250)
duration:  1.5571658611297607

dz2/dx:  tensor(-13.2334)
duration:  1.3989012241363525

so calculating dz2/dx takes about as long as calculating dz1/dx.

The calculation of dz2/dx could be sped up if PyTorch stored dy/dx during the calculation of dz1/dx and then reused that result during the calculation of dz2/dx.
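
For example, the chain rule could be applied by hand, reusing a single expensive dy/dx (a sketch using torch.autograd.grad on the x and y from the example above, with dz1/dy and dz2/dy written out manually):

dy_dx, = torch.autograd.grad(y, x)          # expensive gradient, computed only once
dz1_dx = 2*y.detach()*dy_dx                 # dz1/dy * dy/dx
dz2_dx = dy_dx/(2*torch.sqrt(y.detach()))   # dz2/dy * dy/dx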

Is there a mechanism built into PyTorch to achieve this behavior?

Upvotes: 4

Views: 201

Answers (2)

Ivan

Reputation: 40768

To complement @Karl's answer, here is an alternative solution that doesn't require detaching tensors. It uses torch.autograd.functional.jacobian to compute and extract the gradients. Here again, dy/dx is computed once and reused via the chain rule.

from torch.autograd.functional import jacobian

# compute dy/dx once (slow_fun, x and y = slow_fun(x) as defined in the question)
dy_dx = jacobian(slow_fun, x)

# compute dz1/dx via the chain rule
dz1_dy = jacobian(lambda y: y**2, y)
dz1_dx = dy_dx*dz1_dy

# compute dz2/dx via the chain rule
dz2_dy = jacobian(lambda y: y.sqrt(), y)
dz2_dx = dy_dx*dz2_dy
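
To see where the time goes, you can time the two stages separately (a sketch reusing slow_fun, x and y from the question; only the first jacobian call touches the expensive function, so the second stage should be close to free):

import time
from torch.autograd.functional import jacobian

# stage 1: the only expensive gradient
start = time.time()
dy_dx = jacobian(slow_fun, x)
print("dy/dx duration:", time.time() - start)

# stage 2: cheap scalar derivatives, combined with the cached dy/dx
start = time.time()
dz1_dx = jacobian(lambda y: y**2, y)*dy_dx
dz2_dx = jacobian(lambda y: y.sqrt(), y)*dy_dx
print("dz1/dx and dz2/dx duration:", time.time() - start)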

Upvotes: 0

Karl

Reputation: 5473

You can decouple the nested function using the chain rule. However, there will be some differences in the results due to numerical issues.

import torch
import time

def slow_fun(x):
    A = x*torch.ones((1000,1000))
    B = torch.matrix_exp(1j*A)
    return torch.real(torch.trace(B))

# baseline z1
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z1 = y**2
z1.backward()
dz1_dx = x.grad
print(dz1_dx)
> tensor(-1648274.2500)

# baseline z2
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
z2 = torch.sqrt(y)
z2.backward()
dz2_dx = x.grad
print(dz2_dx)
> tensor(-13.1979)

# compute just dy/dx
x = torch.tensor(1.0, requires_grad = True)
y = slow_fun(x)
y.backward()
dy_dx = x.grad

# detach y to prevent full backprop
y1 = y.detach()
y1.requires_grad = True
z1 = y1**2
z1.backward()
dz1_dy = y1.grad
# compute gradient with chain rule
dz1_dx = dz1_dy * dy_dx
print(dz1_dx)
> tensor(-1672148.5000)

# detach y to prevent full backprop
y2 = y.detach()
y2.requires_grad = True
z2 = torch.sqrt(y2)
z2.backward()
dz2_dy = y2.grad
# compute gradient with chain rule
dz2_dx = dz2_dy * dy_dx
print(dz2_dx)
> tensor(-13.1980)
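
The reason this avoids the double cost: detaching y cuts the autograd graph at y, so the backward calls for z1 and z2 stop at y1/y2 and never re-enter slow_fun. The expensive dy/dx comes from a single backward pass through slow_fun and is then combined by hand with the cheap scalar derivatives dz1/dy and dz2/dy.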

Upvotes: 0
