user570593
user570593

Reputation: 3510

Understanding the code in pyTorch

I am having trouble understanding the following part of the code for the ResNet architecture. The full code is available at https://github.com/yunjey/pytorch-tutorial/blob/master/tutorials/02-intermediate/deep_residual_network/main-gpu.py . I am not very familiar with Python.

# Residual Block
class ResidualBlock(nn.Module):
    """Two conv3x3 -> BatchNorm stages whose output is added to a shortcut.

    Args:
        in_channels: channel count passed to the first conv3x3.
        out_channels: channel count produced by both conv3x3 calls and
            normalised by both BatchNorm2d layers.
        stride: stride handed to the first conv3x3 only.
        downsample: optional callable applied to the block input to produce
            the shortcut; when None the input itself is the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # conv3x3 is a helper defined elsewhere in the tutorial script.
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Compare against None explicitly: truth-testing a module container
        # falls back to __len__, so an empty container would be skipped.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet Module
class ResNet(nn.Module):
    """Stem (conv3x3 -> BatchNorm -> ReLU), three stages, pooling, classifier.

    Args:
        block: callable used to build every residual block; it is called as
            block(in_channels, out_channels, stride, downsample) for the
            first block of a stage and block(out_channels, out_channels)
            for the remaining ones.
        layers: sequence of three ints; layers[i] is the number of blocks
            in stage i + 1.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, block, layers, num_classes=10):
        super(ResNet, self).__init__()
        # Tracks the channel count entering the next stage; make_layer
        # updates it after building each stage.
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Each stage reads its own entry of `layers` (previously layers[0]
        # was used twice and layers[2] was never read).
        self.layer1 = self.make_layer(block, 16, layers[0])
        self.layer2 = self.make_layer(block, 32, layers[1], 2)
        self.layer3 = self.make_layer(block, 64, layers[2], 2)
        # NOTE(review): the 8x8 pooling window followed by Linear(64, ...)
        # presumably assumes 32x32 inputs - confirm against the data loader.
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of `blocks` residual blocks as an nn.Sequential.

        Only the first block receives `stride` and the downsample shortcut;
        the rest map out_channels to out_channels with default arguments.
        Updates self.in_channels to out_channels as a side effect.
        """
        downsample = None
        # The shortcut needs its own conv + BatchNorm whenever the first
        # block changes the stride or the channel count.
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = [block(self.in_channels, out_channels, stride, downsample)]
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(block(out_channels, out_channels))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the three stages, pooling, flatten and classifier."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# Build the network: ResidualBlock is the block class and [3, 3, 3] is the
# `layers` argument whose entries are handed to make_layer as block counts.
resnet = ResNet(ResidualBlock, [3, 3, 3])

My main question is: why should we pass 'block' every time? In the function

def make_layer(self, block, out_channels, blocks, stride=1):

instead of passing 'block', why can't we create an instance of 'ResidualBlock' and append it to layers as follows?

   block = ResidualBlock(self.in_channels, out_channels, stride, downsample)
   layers.append(block)

Upvotes: 2

Views: 4051

Answers (1)

layog
layog

Reputation: 4801

The ResNet module is designed to be generic, so that it can create networks with arbitrary blocks. If you do not pass in the block type you want to use, you'll have to write the name of the block explicitly, as shown below.

# Residual Block
class ResidualBlock(nn.Module):
    """Two conv3x3 -> BatchNorm stages whose output is added to a shortcut.

    Args:
        in_channels: channel count passed to the first conv3x3.
        out_channels: channel count produced by both conv3x3 calls and
            normalised by both BatchNorm2d layers.
        stride: stride handed to the first conv3x3 only.
        downsample: optional callable applied to the block input to produce
            the shortcut; when None the input itself is the shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1, downsample=None):
        super(ResidualBlock, self).__init__()
        # conv3x3 is a helper defined elsewhere in the tutorial script.
        self.conv1 = conv3x3(in_channels, out_channels, stride)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut)."""
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # Compare against None explicitly: truth-testing a module container
        # falls back to __len__, so an empty container would be skipped.
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out

# ResNet Module
class ResNet(nn.Module):
    """Stem (conv3x3 -> BatchNorm -> ReLU), three stages, pooling, classifier.

    This variant hard-codes ResidualBlock as the block type instead of
    taking it as a constructor argument.

    Args:
        layers: sequence of three ints; layers[i] is the number of blocks
            in stage i + 1.
        num_classes: number of outputs of the final linear layer.
    """

    def __init__(self, layers, num_classes=10):
        super(ResNet, self).__init__()
        # Tracks the channel count entering the next stage; make_layer
        # updates it after building each stage.
        self.in_channels = 16
        self.conv = conv3x3(3, 16)
        self.bn = nn.BatchNorm2d(16)
        self.relu = nn.ReLU(inplace=True)
        # Each stage reads its own entry of `layers` (previously layers[0]
        # was used twice and layers[2] was never read).
        self.layer1 = self.make_layer(16, layers[0])
        self.layer2 = self.make_layer(32, layers[1], 2)
        self.layer3 = self.make_layer(64, layers[2], 2)
        # NOTE(review): the 8x8 pooling window followed by Linear(64, ...)
        # presumably assumes 32x32 inputs - confirm against the data loader.
        self.avg_pool = nn.AvgPool2d(8)
        self.fc = nn.Linear(64, num_classes)

    def make_layer(self, out_channels, blocks, stride=1):
        """Build one stage of `blocks` ResidualBlocks as an nn.Sequential.

        Only the first block receives `stride` and the downsample shortcut;
        the rest map out_channels to out_channels with default arguments.
        Updates self.in_channels to out_channels as a side effect.
        """
        downsample = None
        # The shortcut needs its own conv + BatchNorm whenever the first
        # block changes the stride or the channel count.
        if (stride != 1) or (self.in_channels != out_channels):
            downsample = nn.Sequential(
                conv3x3(self.in_channels, out_channels, stride=stride),
                nn.BatchNorm2d(out_channels))
        layers = []
        layers.append(ResidualBlock(self.in_channels, out_channels, stride, downsample))   # Major change here
        self.in_channels = out_channels
        for _ in range(1, blocks):
            layers.append(ResidualBlock(out_channels, out_channels))    # Major change here
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the stem, the three stages, pooling, flatten and classifier."""
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.avg_pool(out)
        # Flatten everything except the batch dimension for the linear layer.
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

# Only the `layers` list is passed here: the block class is hard-coded as
# ResidualBlock inside make_layer.
resnet = ResNet([3, 3, 3])

This reduces the flexibility of your ResNet module and ties it to ResidualBlock only. Now, if you create some other type of block (say ResidualBlock2), you will need to create another module (say ResNet2) specifically for it. So it's better to write a generic ResNet module that takes the block as a parameter, so that it can be used with different types of blocks.

A trivial python example to clarify

Suppose you want to create a function that applies a mathematical operation to each element of a list and returns the results. You might write something like the following:

def exp(inp_list):
    """Return a new list holding math.exp(value) for every value in inp_list."""
    return [math.exp(value) for value in inp_list]

def floor(inp_list):
    """Return a new list holding math.floor(value) for every value in inp_list."""
    return list(map(math.floor, inp_list))

Here, we are doing an exponent and a floor operation on some input list. But, we can do a better job by defining a generic function to do the same as

def apply_func(fn, inp_list):
    """Call fn on each element of inp_list and return the results as a list."""
    return [fn(element) for element in inp_list]

and now call apply_func as apply_func(math.exp, inp_list) for the exponential and as apply_func(math.floor, inp_list) for the floor function. This also opens up the possibility of applying any other kind of operation.

Note: This is not a practical example, since you can always use map or a list comprehension to achieve the same thing, but it demonstrates the idea clearly.

Upvotes: 3

Related Questions