sosamm

Reputation: 49

Trying to understand what "save_for_backward" is in PyTorch

I have some knowledge of PyTorch, but I don't really understand the mechanisms of classes within it. For example, in the link https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html you can find the following code:

import torch

class MyReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        # Stash the input tensor so backward() can use it later
        ctx.save_for_backward(input)
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        # Retrieve the tensor saved during forward()
        input, = ctx.saved_tensors
        grad_input = grad_output.clone()
        # Gradient of ReLU: zero wherever the input was negative
        grad_input[input < 0] = 0
        return grad_input

I am only focusing on the forward method of this class, and I am wondering what

ctx.save_for_backward(input)

does. Whether the previous line of code is present or not seems irrelevant when I try the forward method on a concrete example:

a = torch.eye(3)
rel = MyReLU()
print(rel.forward(rel, a))

as I get the same result in both cases. Could someone explain to me what is happening and why it is useful to add save_for_backward? Thank you in advance.

Upvotes: 3

Views: 5740

Answers (1)

myrtlecat

Reputation: 2276

The ctx.save_for_backward method is used to store values generated during forward() that will be needed later when performing backward(). The saved values can be accessed during backward() from the ctx.saved_tensors attribute.
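The reason your experiment shows no difference is that calling forward() directly, as you do, bypasses autograd entirely: backward() is never invoked, so the saved tensor is never read, and save_for_backward has no effect on the value forward() returns. The saving only pays off when the function is used through its apply method and gradients are actually computed. Here is a minimal sketch of the full round trip (repeating the MyReLU class from the question; the input torch.eye(3) is just an arbitrary example):

import torch

class MyReLU(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)   # stash input for use in backward()
        return input.clamp(min=0)

    @staticmethod
    def backward(ctx, grad_output):
        input, = ctx.saved_tensors     # retrieve what forward() saved
        grad_input = grad_output.clone()
        grad_input[input < 0] = 0      # ReLU gradient: zero where input < 0
        return grad_input

a = torch.eye(3, requires_grad=True)

# apply() runs forward() and records the op on the autograd graph
out = MyReLU.apply(a)

# autograd now calls MyReLU.backward(), which reads ctx.saved_tensors
out.sum().backward()

print(a.grad)   # all ones here, since no entry of torch.eye(3) is negative

If you delete the ctx.save_for_backward(input) line, the forward output is unchanged, but the backward() call above fails: ctx.saved_tensors is then an empty tuple, so the unpacking input, = ctx.saved_tensors raises an error.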

Upvotes: 5
