Reputation: 49
I have some knowledge of PyTorch, but I don't really understand how its classes work. For example, at https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html you can find the following code:
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0
        return grad_input
I am only focusing on the forward method of this class, and I am wondering what
ctx.save_for_backward(input)
does. Whether that line of code is present or not makes no difference when I try the forward method on a concrete example:
a = torch.eye(3)
rel = MyReLU()
print(rel.forward(rel, a))
as I get the same result in both cases. Could someone explain to me what is happening and why it is useful to add save_for_backward? Thank you in advance.
Upvotes: 3
Views: 5740
Reputation: 2276
The ctx.save_for_backward method is used to store values generated during forward() that will be needed later when performing backward(). The saved values can be accessed during backward() via the ctx.saved_tensors attribute.
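To make the round trip concrete, here is a minimal sketch using the MyReLU class from the question (the input tensor is made up for illustration). forward() stashes the input via ctx.save_for_backward(), and backward() retrieves it from ctx.saved_tensors to decide which gradient entries to zero out. Note that a custom Function should be invoked through its apply() classmethod, so that autograd records the operation and later calls backward() for you:

```python
import torch

class MyReLU(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input):
        # Stash the input; backward() will need it to build the ReLU mask.
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve what forward() saved.
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Gradient of ReLU: pass grad through where input > 0, zero elsewhere.
        grad_input[input < 0] = 0
        return grad_input

x = torch.tensor([-1.0, 0.5, 2.0], requires_grad=True)
y = MyReLU.apply(x)   # use apply(), not forward() directly
y.sum().backward()    # autograd calls MyReLU.backward, which reads ctx.saved_tensors
print(x.grad)         # tensor([0., 1., 1.])
```

Calling forward() alone never touches the saved tensors, which is why removing save_for_backward changes nothing in your experiment; it is backward() that would fail (ctx.saved_tensors would be empty) once gradients are actually computed.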
Upvotes: 5