Long WU

Reputation: 45

Best way to debug or step over a sequential pytorch model

I used to write my PyTorch models as nn.Module subclasses with __init__ and forward, so that I could step through the model and check how the tensor dimensions change along the network. However, I have since realized that you can also build a model with nn.Sequential, which only requires an __init__; you don't need to write a forward function, as below:

screenshot of an example from pytorch

However, when I try to step over this network, it is no longer easy to inspect the variables. The debugger just jumps to another place and back.

Does anyone know how to step through the model in this situation?

P.S.: I am using PyCharm.

Upvotes: 2

Views: 2599

Answers (1)

Umang Gupta

Reputation: 16470

You can iterate over the children of the model, as below, and print the output sizes for debugging. This is similar to writing forward, but you write a separate function instead of creating an nn.Module subclass.

import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(1,20,5),
    nn.ReLU(),
    nn.Conv2d(20,64,5),
    nn.ReLU()
)

def print_sizes(model, input_tensor):
    output = input_tensor
    for m in model.children():
        output = m(output)
        print(m, output.shape)
    return output

input_tensor = torch.rand(100, 1, 28, 28)
print_sizes(model, input_tensor)

# output: 
# Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) torch.Size([100, 20, 24, 24])
# ReLU() torch.Size([100, 20, 24, 24])
# Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1)) torch.Size([100, 64, 20, 20])
# ReLU() torch.Size([100, 64, 20, 20])

# You can also nest Sequential models like this. In this case the inner Sequential is treated as a single module.
model1 = nn.Sequential(
    nn.Conv2d(1,20,5),
    nn.ReLU(),
    nn.Sequential(
        nn.Conv2d(20,64,5),
        nn.ReLU()
    )
)

print_sizes(model1, input_tensor)

# output: 
# Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) torch.Size([100, 20, 24, 24])
# ReLU() torch.Size([100, 20, 24, 24])
# Sequential(
#     (0): Conv2d(20, 64, kernel_size=(5, 5), stride=(1, 1))
#     (1): ReLU()
# ) torch.Size([100, 64, 20, 20])
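
If you would rather keep calling the model directly instead of looping over its children, forward hooks are another option. The sketch below is not part of the approach above; it assumes the same layer stack and uses hypothetical names (record_shape, recorded_shapes) chosen only for illustration. Since the hook is ordinary Python code of your own, a PyCharm breakpoint inside it should stop once per layer, with the layer's input and output available for inspection.

```python
import torch
from torch import nn

model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
)

# Hypothetical helper list: collects (layer class name, output shape) pairs.
recorded_shapes = []


def record_shape(module, inputs, output):
    # Called after each hooked layer's forward pass.
    # A debugger breakpoint can be placed on the next line.
    recorded_shapes.append((type(module).__name__, tuple(output.shape)))


# Attach the hook to every direct child of the Sequential container.
handles = [layer.register_forward_hook(record_shape) for layer in model.children()]

with torch.no_grad():
    model(torch.rand(2, 1, 28, 28))

# Remove the hooks so later forward passes are not affected.
for handle in handles:
    handle.remove()

for name, shape in recorded_shapes:
    print(name, shape)
```

As with print_sizes, a nested Sequential counts as one child here; iterating over model.modules() instead would also hook the inner layers, at the cost of hooking the containers themselves.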

Upvotes: 2
