L0KiZ

Reputation: 89

Implementing a simple ResNet block with PyTorch

I'm trying to implement the following ResNet block: a ResNet consists of blocks with two convolutional layers and a skip connection. For some reason, my block doesn't add the output of the skip connection (when one is applied) or the input to the output of the convolutional layers.


My code:

import torch.nn as nn
import torch.nn.functional as F

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        """
        Args:
          in_channels (int):  Number of input channels.
          out_channels (int): Number of output channels.
          stride (int):       Controls the stride.
        """
        super(Block, self).__init__()

        self.skip = nn.Sequential()

        if stride != 1 or in_channels != out_channels:
          self.skip = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, stride=stride, bias=False),
            nn.BatchNorm2d(out_channels))
        else:
          self.skip = None

        self.block = nn.Sequential(
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels))

    def forward(self, x):
        out = self.block(x)

        if self.skip is not None:
          out = self.skip(x)
        else:
          out = x

        out += x

        out = F.relu(out)
        return out

Upvotes: 6

Views: 14785

Answers (2)

xxii111

Reputation: 1

By the way, the in_channels of the second Conv2d should be the same as the out_channels of the first Conv2d:

self.block = nn.Sequential(
        nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1, stride=1, bias=False),
        nn.BatchNorm2d(out_channels))
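
To see why this matters, here is a rough sketch (the channel sizes 64 and 128 are made up purely for illustration): with the block from the question, as soon as in_channels differs from out_channels, the second convolution receives a tensor whose channel count no longer matches its weights and the forward pass fails.

import torch
import torch.nn as nn

# Hypothetical channel sizes (64 -> 128), chosen only to trigger the mismatch.
broken = nn.Sequential(
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    # Bug reproduced from the question: in_channels here should be 128, not 64.
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(128))

x = torch.randn(1, 64, 32, 32)
try:
    broken(x)
except RuntimeError as err:
    print(err)  # the second conv expects 64 input channels but receives 128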

Upvotes: 0

Berriel

Reputation: 13651

The problem is the reuse of the out variable: you overwrite the result of self.block(x) instead of adding to it. Normally, you'd implement it like this:

def forward(self, x):
    identity = x
    out = self.block(x)

    if self.skip is not None:
        identity = self.skip(x)

    out += identity
    out = F.relu(out)

    return out

If you like "one-liners":

def forward(self, x):
    out = self.block(x)
    out += (x if self.skip is None else self.skip(x))
    out = F.relu(out)
    return out

If you really like one-liners (please, that is too much, do not choose this option :))

def forward(self, x):
    return F.relu(self.block(x) + (x if self.skip is None else self.skip(x)))
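
As a quick sanity check, here is a minimal sketch with made-up shapes, assuming the fixed forward above replaces the original one and keeping in_channels == out_channels so that self.skip stays None and the identity path is just x:

import torch

block = Block(in_channels=64, out_channels=64)  # skip path is None here
x = torch.randn(1, 64, 32, 32)
out = block(x)
print(out.shape)  # torch.Size([1, 64, 32, 32]); x is added to the conv output before the final ReLU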

Upvotes: 12
