Dirk Li

Reputation: 199

Why does my PyTorch model perform poorly after calling eval()?

I used PyTorch to build a segmentation model that uses BatchNorm layers. I found that when I call model.eval() before testing, the test output is all zeros. If I don't call model.eval(), the model performs well.

I searched for related questions and learned that model.eval() freezes the running statistics of the BatchNorm layers, but I am still not sure how to solve this problem.

My batch size is 1, and this is my model:

import torch
import torch.nn as nn


class Encode_Block(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Encode_Block, self).__init__()

        self.conv1 = Res_Block(in_feat, out_feat)
        self.conv2 = Res_Block_identity(out_feat, out_feat)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class Decode_Block(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Decode_Block, self).__init__()

        self.conv1 = Res_Block(in_feat, out_feat)
        self.conv2 = Res_Block_identity(out_feat, out_feat)

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        outputs = self.conv2(outputs)
        return outputs


class Conv_Block(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(Conv_Block, self).__init__()

        self.conv1 = nn.Sequential(
            nn.Conv2d(in_feat, out_feat, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(),
        )

    def forward(self, inputs):
        outputs = self.conv1(inputs)
        return outputs


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(
        in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False
    )


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)


class Res_Block(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(Res_Block, self).__init__()
        self.conv_input = conv1x1(inplanes, planes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.conv3 = conv1x1(planes, planes)
        self.stride = stride

    def forward(self, x):
        residual = self.conv_input(x)

        out = self.conv1(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn(out)

        out += residual
        out = self.relu(out)

        return out


class Res_Block_identity(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(Res_Block_identity, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.conv3 = conv1x1(planes, planes)
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn(out)

        out += residual
        out = self.relu(out)

        return out


class UpConcat(nn.Module):
    def __init__(self, in_feat, out_feat):
        super(UpConcat, self).__init__()

        self.de_conv = nn.ConvTranspose2d(in_feat, out_feat, kernel_size=2, stride=2)

    def forward(self, inputs, down_outputs):
        outputs = self.de_conv(inputs)
        out = torch.cat([down_outputs, outputs], 1)
        return out


class Res_UNet(nn.Module):
    def __init__(self, num_channels=1, num_classes=1):
        super(Res_UNet, self).__init__()
        flt = 64
        self.down1 = Encode_Block(num_channels, flt)
        self.down2 = Encode_Block(flt, flt * 2)
        self.down3 = Encode_Block(flt * 2, flt * 4)
        self.down4 = Encode_Block(flt * 4, flt * 8)
        self.down_pool = nn.MaxPool2d(kernel_size=2)
        self.bottom = Encode_Block(flt * 8, flt * 16)
        self.up_cat1 = UpConcat(flt * 16, flt * 8)
        self.up_conv1 = Decode_Block(flt * 16, flt * 8)
        self.up_cat2 = UpConcat(flt * 8, flt * 4)
        self.up_conv2 = Decode_Block(flt * 8, flt * 4)
        self.up_cat3 = UpConcat(flt * 4, flt * 2)
        self.up_conv3 = Decode_Block(flt * 4, flt * 2)
        self.up_cat4 = UpConcat(flt * 2, flt)
        self.up_conv4 = Decode_Block(flt * 2, flt)
        self.final = nn.Sequential(
            nn.Conv2d(flt, num_classes, kernel_size=1), nn.Sigmoid()
        )

    def forward(self, inputs):
        down1_feat = self.down1(inputs)
        pool1_feat = self.down_pool(down1_feat)
        down2_feat = self.down2(pool1_feat)
        pool2_feat = self.down_pool(down2_feat)
        down3_feat = self.down3(pool2_feat)
        pool3_feat = self.down_pool(down3_feat)
        down4_feat = self.down4(pool3_feat)
        pool4_feat = self.down_pool(down4_feat)

        bottom_feat = self.bottom(pool4_feat)

        up1_feat = self.up_cat1(bottom_feat, down4_feat)
        up1_feat = self.up_conv1(up1_feat)
        up2_feat = self.up_cat2(up1_feat, down3_feat)
        up2_feat = self.up_conv2(up2_feat)
        up3_feat = self.up_cat3(up2_feat, down2_feat)
        up3_feat = self.up_conv3(up3_feat)
        up4_feat = self.up_cat4(up3_feat, down1_feat)
        up4_feat = self.up_conv4(up4_feat)

        outputs = self.final(up4_feat)

        return outputs

The model completely fails at segmentation after model.eval() is set, but performs well when model.eval() is removed. I am confused by this. Is model.eval() necessary for testing?

Upvotes: 5

Views: 2808

Answers (1)

Deb

Reputation: 1098

BatchNorm layers keep running estimates of the mean and variance computed during training (model.train()), which are then used for normalization during evaluation (model.eval()).
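For illustration, here is a small self-contained check of that behavior (the variable names are my own, not from your code):

import torch
import torch.nn as nn

bn = nn.BatchNorm2d(3)

bn.train()                    # training mode: normalize with batch statistics
x = torch.randn(8, 3, 4, 4)  # and update running_mean / running_var as a side effect
bn(x)
print(bn.running_mean)       # has moved toward the batch mean

bn.eval()                    # eval mode: normalize with the stored running
y = bn(x)                    # estimates; running_mean / running_var stay fixed
print(bn.running_mean)       # unchanged by the eval-mode forward pass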

Each BatchNorm layer keeps its own statistics of the mean and variance of its inputs/activations. Since you reuse the single BatchNorm layer self.bn = nn.BatchNorm2d(planes) at three different points in forward, the statistics of three different activations get mixed together and no longer represent the actual mean and variance of any of them. You should create a separate BatchNorm layer for every place you use one.
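Here is a minimal sketch of that fix applied to your Res_Block (the names bn1, bn2, bn3 are my own; the rest mirrors your code):

class Res_Block(nn.Module):
    def __init__(self, inplanes, planes, stride=1):
        super(Res_Block, self).__init__()
        self.conv_input = conv1x1(inplanes, planes)
        self.conv1 = conv3x3(inplanes, planes, stride)
        # one BatchNorm per normalization site, so each layer tracks
        # the running statistics of exactly one activation
        self.bn1 = nn.BatchNorm2d(planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.conv3 = conv1x1(planes, planes)
        self.stride = stride

    def forward(self, x):
        residual = self.conv_input(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += residual
        return self.relu(out)

Res_Block_identity needs the same change.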

EDIT: I just read that your batch_size is 1, which could also be the core of your problem: see Tensorflow and Batch Normalization with Batch Size==1 => Outputs all zeros
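If increasing the batch size is not an option, one common workaround (my suggestion, not covered by the link above) is a normalization layer that does not depend on batch statistics at all, such as nn.GroupNorm:

import torch
import torch.nn as nn

planes = 64  # example channel count, matching flt in Res_UNet
# GroupNorm normalizes over channel groups within each sample, so it
# behaves identically at any batch size and keeps no running statistics
norm = nn.GroupNorm(num_groups=8, num_channels=planes)  # drop-in for nn.BatchNorm2d(planes)

x = torch.randn(1, planes, 16, 16)  # batch size 1 is fine
print(norm(x).shape)  # torch.Size([1, 64, 16, 16])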

Upvotes: 4
