Reputation: 148
For backpropagation in PyTorch, the gradients of many simple functions are of course already implemented.
But what if I want a function that evaluates the gradient of an existing primitive function directly, e.g. the derivative of torch.sigmoid(x) with respect to x? I'd also like to be able to backpropagate through this new function.
The goal would be something like the following, but using only torch.sigmoid instead of a custom (re-)implementation.
import torch
import matplotlib.pyplot as plt
def dsigmoid_dx(x):
    return torch.sigmoid(x) * (1 - torch.sigmoid(x))
xx = torch.linspace(-3.5, 3.5, 100)
yy = dsigmoid_dx(xx)
# ... do other stuff with yy
Of course, I could make x require gradients, pass it through the function, and then use autograd, e.g. as follows:
import torch
import matplotlib.pyplot as plt
xx = torch.linspace(-3.5, 3.5, 100, requires_grad=True)
yy = torch.sigmoid(xx)
grad = torch.autograd.grad(yy, [xx], grad_outputs=torch.ones_like(yy), create_graph=True)[0]
plt.plot(xx.detach(), grad.detach())
plt.plot(xx.detach(), yy.detach(), color='red')
plt.show();
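Since create_graph=True keeps the graph of the derivative itself, the gradient computed this way can be backpropagated through again. As a minimal sketch (same setup as above, the variable names are my own), taking a second derivative of torch.sigmoid:
import torch

xx = torch.linspace(-3.5, 3.5, 100, requires_grad=True)
yy = torch.sigmoid(xx)
# first derivative; create_graph=True keeps it differentiable
dy_dx = torch.autograd.grad(yy, [xx], grad_outputs=torch.ones_like(yy), create_graph=True)[0]
# backpropagate through the derivative itself -> second derivative of sigmoid
d2y_dx2 = torch.autograd.grad(dy_dx, [xx], grad_outputs=torch.ones_like(dy_dx))[0]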
Is it (for individual, primitive functions) possible to somehow directly access the implemented backward function?
In the PyTorch docs it's shown how to extend autograd, but I can't figure out how to directly access these backward functions for existing ones (again, e.g. torch.sigmoid).
To summarize, I want to avoid having to reimplement simple derivatives of functions that are obviously already implemented in the framework (and presumably in a numerically stable way). Is this possible? Or do I always have to reimplement them myself?
Upvotes: 1
Views: 121
Reputation: 40628
Since the computation of yy only involves a single (native) function, torch.sigmoid, ultimately calling autograd.grad (or similarly yy.backward) will result in directly calling the implemented backward function of sigmoid, which by the looks of it is what you are looking for in the first place. In other words, backpropagating on yy is the exact definition of accessing (i.e. calling) sigmoid's implemented backward for a given point.
So one alternative interface you can use is backward:
import torch
import matplotlib.pyplot as plt

xx = torch.linspace(-3.5, 3.5, 100, requires_grad=True)
yy = torch.sigmoid(xx)
yy.sum().backward()
plt.plot(xx.detach(), xx.grad)
plt.plot(xx.detach(), yy.detach(), color='red')
plt.show()
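And if you want to package this as the dsigmoid_dx helper from your question while still relying only on torch.sigmoid and its built-in backward, a rough sketch (the wrapper itself is my own, using autograd.grad with create_graph=True so the result remains differentiable):
def dsigmoid_dx(x):
    # sketch: differentiate torch.sigmoid via autograd instead of re-implementing
    # the derivative; create_graph=True keeps the output differentiable
    with torch.enable_grad():
        x = x if x.requires_grad else x.detach().requires_grad_(True)
        y = torch.sigmoid(x)
        grad, = torch.autograd.grad(y, [x], grad_outputs=torch.ones_like(y), create_graph=True)
    return grad

xx = torch.linspace(-3.5, 3.5, 100)
yy = dsigmoid_dx(xx)  # same values as the hand-written derivative, via sigmoid's own backward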
Upvotes: 1