amp-likes-linux

Reputation: 9

What's the correct way of expressing Residual Block with forward function of pytorch?

AFAIK there are two ways to express a ResNet block in PyTorch: either overwrite x as it flows through the layers while keeping the input in a temporary, or leave x untouched and build the result in a new variable. Which leads to two kinds of code:

# Version 1: overwrite x through the layers, keep the input in y
def forward(self, x):
    y = x
    x = self.conv1(x)
    x = self.norm1(x)
    x = self.act1(x)
    x = self.conv2(x)
    x = self.norm2(x)
    x += y  # add the saved input back
    x = self.act2(x)
    return x

# Version 2: keep the input in x, build the result in y
def forward(self, x):
    y = self.conv1(x)
    y = self.norm1(y)
    y = self.act1(y)
    y = self.conv2(y)
    y = self.norm2(y)
    y += x  # add the input back
    y = self.act2(y)
    return y

Are they identical? Which one is preferred? Why?

Upvotes: 0

Views: 199

Answers (1)

Karl

Reputation: 5473

It doesn't matter so long as you retain some reference to the input.

At a high level, you are trying to compute output = activation(input + f(input)).

Both methods shown accomplish this. As long as you don't lose the reference to the input, and don't modify the input tensor with an in-place operation, you should be fine.
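
To illustrate the in-place pitfall (a minimal sketch, not from the original answer): saving the input with y = x only stores a reference, so mutating x in place silently changes the "saved" input too, and the residual you add back is no longer the original input:

import torch

x = torch.zeros(3)
y = x          # y is another reference to the same tensor, not a copy
x += 1         # in-place: mutates the tensor y points to as well
print(y)       # tensor([1., 1., 1.]) -- the saved "input" is gone

# Rebinding is safe by contrast: x = x + 1 creates a new tensor and
# leaves y untouched, which is why both forward() versions above work.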

For what it's worth, I would separate out the residual connection and the sub-block just for clarity:

import torch.nn as nn

class Block(nn.Module):
    def __init__(self, ...):
        super().__init__()
        self.conv1 = ...
        self.norm1 = ...
        self.act = ...
        self.conv2 = ...
        self.norm2 = ...

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return x

class ResBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.act = ...

    def forward(self, x):
        return self.act(x + self.block(x))
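
As a usage sketch, here is one way the elided layers might be filled in; the specific choices below (3x3 convs, BatchNorm, ReLU, a channel count of 16) are my illustrative assumptions, not part of the answer above:

import torch
import torch.nn as nn

# Hypothetical concrete sub-block standing in for Block above.
class ConvBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm1 = nn.BatchNorm2d(channels)
        self.act = nn.ReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.norm2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x)
        return x

# Mirrors the ResBlock above, with a concrete activation filled in.
class ResBlock(nn.Module):
    def __init__(self, block):
        super().__init__()
        self.block = block
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(x + self.block(x))

res = ResBlock(ConvBlock(channels=16))
out = res(torch.randn(2, 16, 8, 8))  # batch of 2, 16 channels, 8x8
print(out.shape)                     # torch.Size([2, 16, 8, 8])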

Upvotes: 0
