Inkplay_

Reputation: 561

How do you use next_functions[0][0] on grad_fn correctly in pytorch?

I was given this nn structure in the official PyTorch tutorial:

input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d -> view -> linear -> relu -> linear -> relu -> linear -> MSELoss -> loss

Then it gives an example of how to follow the gradient graph backwards using the built-in .grad_fn attribute of a Variable:

# Eg: 
print(loss.grad_fn)  # MSELoss
print(loss.grad_fn.next_functions[0][0])  # Linear
print(loss.grad_fn.next_functions[0][0].next_functions[0][0])  # ReLU
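
It helps to know what each entry of next_functions actually is. The sketch below, a toy graph of my own rather than the tutorial's network, shows that next_functions is a tuple of (node, input_index) pairs, one per input of the recorded op, and that a walk ends at an AccumulateGrad node, which has no inputs. Exact node class names (for example SumBackward0 vs SumBackward) can differ between PyTorch versions.

```python
import torch

# A leaf tensor that requires grad; autograd records the ops applied to it.
x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()

# Each grad_fn node exposes next_functions: a tuple of (node, input_index)
# pairs, one per input of the recorded op. Inputs that need no gradient
# show up with None in place of a node.
sum_node = y.grad_fn
print(sum_node, sum_node.next_functions)

mul_node = sum_node.next_functions[0][0]
print(mul_node, mul_node.next_functions)

# The first input of the multiply is x itself, represented by an
# AccumulateGrad node that writes into x.grad. It has no inputs of its own.
leaf_node = mul_node.next_functions[0][0]
print(type(leaf_node).__name__, leaf_node.next_functions)
```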

So I thought I could reach the grad_fn node for Conv2d by chaining .next_functions[0][0] nine times, following the pattern in the examples, but I got the error "tuple index out of range". How do I index these backprop objects correctly?
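
A minimal reproduction of that error, assuming a single nn.Linear layer as a stand-in for the tutorial network: always taking [0][0] stops at the first leaf it meets. On the PyTorch versions I know of, the first input of a Linear layer's node is its bias, which is an AccumulateGrad leaf with an empty next_functions tuple, so the next [0] raises IndexError. The input ordering is an implementation detail, not a documented guarantee.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
out = layer(torch.randn(1, 4))  # the input itself does not require grad

node = out.grad_fn  # the node recorded for the Linear layer
steps = 0
message = ""
try:
    while True:
        # Blindly follow the first input, exactly as chaining [0][0] does.
        node = node.next_functions[0][0]
        steps += 1
except IndexError as err:
    # The first input here is the bias, an AccumulateGrad leaf whose
    # next_functions is the empty tuple, so indexing [0] fails.
    message = str(err)

print(f"stopped after {steps} step(s) at {type(node).__name__}: {message}")
```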

Upvotes: 4

Views: 2347

Answers (2)

Jonathan Shock

Reputation: 121

Try running

print(loss.grad_fn.next_functions[0][0].next_functions)

You will see that this gives a tuple with three elements. It's actually the [1][0] element that you want to pick; otherwise you get an AccumulateGrad node (the accumulated gradient of a parameter), and you can't go further than that. As you dig through, you will see that you can get all the way through the network. For instance, try running:

print(loss.grad_fn.next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[1][0].next_functions[0][0].next_functions[0][0].next_functions[0][0].next_functions)
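
To see the three entries of a Linear node side by side, here is a small sketch using two hypothetical layers fc1 and fc2 in place of the tutorial network. On current PyTorch the entries are typically the bias leaf (AccumulateGrad), the incoming activation (here a ReLU node), and a transpose of the weight, which is why [1][0] is the index that continues through the network.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1, fc2 = nn.Linear(4, 3), nn.Linear(3, 2)
out = fc2(torch.relu(fc1(torch.randn(1, 4))))

# Label every input of the last Linear node by position and node type,
# so you can see which index leads further back into the network.
names = [
    type(fn).__name__ if fn is not None else None
    for fn, _ in out.grad_fn.next_functions
]
for i, name in enumerate(names):
    print(i, name)
```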

First run .next_functions without indexing into it, and then see which element you need to choose to get to the next layer of the network.
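
If you would rather not pick the index by hand at every layer, one option is a heuristic walk: at each node, skip inputs that are None or AccumulateGrad leaves and follow the first remaining one. The sketch below (walk_main_path is a hypothetical helper, and the small Sequential model is a scaled-down stand-in for the tutorial's net) traces the activation path for Linear, ReLU, pooling and Conv nodes on current PyTorch, but autograd does not promise any input ordering, so treat it as a convenience rather than a guarantee.

```python
import torch
import torch.nn as nn

def walk_main_path(grad_fn):
    """Return node type names along the first non-leaf branch at each step.

    Heuristic: skip inputs that are None (no grad needed) or AccumulateGrad
    (parameter leaves) and follow the first remaining input. This relies on
    input ordering that autograd does not formally guarantee.
    """
    path = []
    node = grad_fn
    while node is not None:
        path.append(type(node).__name__)
        node = next(
            (fn for fn, _ in node.next_functions
             if fn is not None and type(fn).__name__ != "AccumulateGrad"),
            None,
        )
    return path

# A tiny conv -> relu -> pool -> flatten -> linear -> relu -> linear net.
net = nn.Sequential(
    nn.Conv2d(1, 2, 3), nn.ReLU(), nn.MaxPool2d(2), nn.Flatten(),
    nn.Linear(2 * 3 * 3, 4), nn.ReLU(), nn.Linear(4, 1),
)
loss = nn.MSELoss()(net(torch.randn(1, 1, 8, 8)), torch.zeros(1, 1))

path = walk_main_path(loss.grad_fn)
for name in path:
    print(name)
```

The walk stops at the first convolution because its input does not require grad and its remaining inputs are the weight and bias leaves.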

Upvotes: 0

Randy Dueck

Reputation: 106

In the PyTorch CNN tutorial, after running the following code from the tutorial:

output = net(input)
target = torch.randn(10)  # a dummy target, for example
target = target.view(1, -1)  # make it the same shape as output
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)
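
The snippet above relies on net and input defined earlier in the tutorial. For a self-contained setup, here is a sketch assuming a LeNet-style Net that matches the structure in the question (two conv/relu/maxpool stages, a view, then three Linear layers); the 5x5 kernels and layer sizes are my assumption and may not match the tutorial version you are reading.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    # Assumed LeNet-style layout; sizes are illustrative, not authoritative.
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)  # 32x32 -> 28x28 -> 14x14
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)  # 14x14 -> 10x10 -> 5x5
        x = x.view(x.size(0), -1)                   # flatten to (N, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

net = Net()
input = torch.randn(1, 1, 32, 32)  # one single-channel 32x32 image
output = net(input)
target = torch.randn(10).view(1, -1)  # dummy target, same shape as output
loss = nn.MSELoss()(output, target)
print(loss.grad_fn)
```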

The following code snippet will print the full graph:

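def print_graph(g, level=0):
    # Inputs that don't require grad appear as None; stop there.
    if g is None:
        return
    # Indent each node by four '*' characters per level of depth.
    print('*' * level * 4, g)
    # next_functions holds (node, input_index) pairs; recurse into each node.
    for subg in g.next_functions:
        print_graph(subg[0], level + 1)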

print_graph(loss.grad_fn, 0)
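
A variant that may make the printout easier to read: label each AccumulateGrad leaf with the name of the parameter it belongs to, using the leaf node's .variable attribute and the model's named_parameters(). This is a sketch with a hypothetical helper graph_lines and a small stand-in model; like print_graph, it walks the graph as a tree, so a subgraph reachable by several paths is printed once per path.

```python
import torch
import torch.nn as nn

def graph_lines(fn, names, level=0, lines=None):
    """Collect one indented line per autograd node, naming parameter leaves."""
    if lines is None:
        lines = []
    if fn is None:  # input that needs no gradient
        return lines
    label = type(fn).__name__
    if hasattr(fn, "variable"):  # AccumulateGrad nodes hold the leaf tensor
        label += f" ({names.get(id(fn.variable), 'unnamed leaf')})"
    lines.append("  " * level + label)
    for child, _ in fn.next_functions:
        graph_lines(child, names, level + 1, lines)
    return lines

model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 1))
loss = model(torch.randn(2, 4)).pow(2).mean()

# Map each parameter's id() to its qualified name, e.g. "0.weight".
names = {id(p): n for n, p in model.named_parameters()}
lines = graph_lines(loss.grad_fn, names)
print("\n".join(lines))
```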

Upvotes: 9
