Reputation: 5791
I want to use PNASNet5Large as the encoder for my UNet. Here is my approach, which is wrong for PNASNet5Large but works for ResNet:
class UNetResNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        if encoder_depth == 34:
            self.encoder = torchvision.models.resnet34(pretrained=pretrained)
            bottom_channel_nr = 512
        elif encoder_depth == 101:
            self.encoder = torchvision.models.resnet101(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 152:  # this works
            self.encoder = torchvision.models.resnet152(pretrained=pretrained)
            bottom_channel_nr = 2048
        elif encoder_depth == 777:  # coded version for the pnasnet
            self.encoder = PNASNet5Large()
            bottom_channel_nr = 4320  # this is unknown for me as well
        self.pool = nn.MaxPool2d(2, 2)
        self.relu = nn.ReLU(inplace=True)
        self.conv1 = nn.Sequential(self.encoder.conv1,
                                   self.encoder.bn1,
                                   self.encoder.relu,
                                   self.pool)
        self.conv2 = self.encoder.layer1  # PNASNet5Large doesn't have such layers
        self.conv3 = self.encoder.layer2
        self.conv4 = self.encoder.layer3
        self.conv5 = self.encoder.layer4
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
        self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                 is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        conv1 = self.conv1(x)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)
        conv4 = self.conv4(conv3)
        conv5 = self.conv5(conv4)
        center = self.center(conv5)
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))
1) How do I get how many bottom channels PNASNet has? It ends the following way:
...
self.cell_11 = Cell(in_channels_left=4320, out_channels_left=864,
                    in_channels_right=4320, out_channels_right=864)
self.relu = nn.ReLU()
self.avg_pool = nn.AvgPool2d(11, stride=1, padding=0)
self.dropout = nn.Dropout(0.5)
self.last_linear = nn.Linear(4320, num_classes)
Is 4320 the answer or not? in_channels_left and out_channels_left are something new for me.
2) ResNet has some kind of 4 big layers which I use as encoders in my UNet architecture. How do I get similar layers from PNASNet?
I'm using PyTorch 3.1 and this is the link to the PNASNet directory.
3) AttributeError: 'PNASNet5Large' object has no attribute 'conv1' - so it doesn't have conv1 either.
UPD: I tried something like this but it failed:
class UNetPNASNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.encoder = PNASNet5Large()
        bottom_channel_nr = 4320
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
        self.dec5 = DecoderBlockV2(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlockV2(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlockV2(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlockV2(num_filters * 4 * 4, num_filters * 4 * 4, num_filters, is_deconv)
        self.dec1 = DecoderBlockV2(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        features = self.encoder.features(x)
        relued_features = self.encoder.relu(features)
        avg_pooled_features = self.encoder.avg_pool(relued_features)
        center = self.center(avg_pooled_features)
        dec5 = self.dec5(torch.cat([center, avg_pooled_features], 1))
        dec4 = self.dec4(torch.cat([dec5, relued_features], 1))
        dec3 = self.dec3(torch.cat([dec4, features], 1))
        dec2 = self.dec2(dec3)
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))
RuntimeError: Given input size: (4320x4x4). Calculated output size: (4320x-6x-6). Output size is too small at /opt/conda/conda-bld/pytorch_1525796793591/work/torch/lib/THCUNN/generic/SpatialAveragePooling.cu:63
Upvotes: 0
Views: 572
Reputation: 4513
So you want to use PNASNet5Large instead of ResNets as the encoder in your UNet architecture. Let's see how the ResNets are used. In your __init__:
self.pool = nn.MaxPool2d(2, 2)
self.relu = nn.ReLU(inplace=True)
self.conv1 = nn.Sequential(self.encoder.conv1,
                           self.encoder.bn1,
                           self.encoder.relu,
                           self.pool)
self.conv2 = self.encoder.layer1
self.conv3 = self.encoder.layer2
self.conv4 = self.encoder.layer3
self.conv5 = self.encoder.layer4
So you use the ResNets up to layer4, which is the last block before the average pooling. The sizes you're using for ResNet are the ones after the average pooling, so I assume there is a self.encoder.avgpool missing after self.conv5 = self.encoder.layer4. The forward of a ResNet in torchvision.models looks like this:
def forward(self, x):
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.maxpool(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.avgpool(x)
    x = x.view(x.size(0), -1)
    x = self.fc(x)
    return x
I guess you want to adopt a similar solution for PNASNet5Large (use the architecture up to the average pooling layer).

1) To get how many channels your PNASNet5Large has, you need to look at the output tensor size after the average pooling, for example by feeding a dummy tensor to it. Also notice that while ResNets are commonly used with input size (batch_size, 3, 224, 224), PNASNet5Large uses (batch_size, 3, 331, 331):
m = PNASNet5Large()
x1 = torch.randn(1, 3, 331, 331)
m.avg_pool(m.features(x1)).size()
# torch.Size([1, 4320, 1, 1])
Therefore yes, bottom_channel_nr = 4320 for your PNASNet.
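Also note the spatial size before the average pooling: avg_pool uses an 11x11 kernel, so the encoder output has to be at least 11x11, which is what a 331x331 input produces. A quick sanity check with a dummy tensor (same setup as above):

m = PNASNet5Large()
x1 = torch.randn(1, 3, 331, 331)
print(m.features(x1).size())  # torch.Size([1, 4320, 11, 11])
# With a smaller input the feature map shrinks below 11x11 and avg_pool fails
# with "Output size is too small" - the RuntimeError you are seeing.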
2) Since the architecture is totally different, you need to modify the __init__ and forward of your UNet. If you decide to use PNASNet, I suggest you make a new class:
class UNetPNASNet(nn.Module):
    def __init__(self, encoder_depth, num_classes, num_filters=32, dropout_2d=0.2,
                 pretrained=False, is_deconv=False):
        super().__init__()
        self.num_classes = num_classes
        self.dropout_2d = dropout_2d
        self.encoder = PNASNet5Large()
        bottom_channel_nr = 4320
        self.center = DecoderCenter(bottom_channel_nr, num_filters * 8 * 2, num_filters * 8, False)
        self.dec5 = DecoderBlock(bottom_channel_nr + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec4 = DecoderBlock(bottom_channel_nr // 2 + num_filters * 8, num_filters * 8 * 2, num_filters * 8, is_deconv)
        self.dec3 = DecoderBlock(bottom_channel_nr // 4 + num_filters * 8, num_filters * 4 * 2, num_filters * 2, is_deconv)
        self.dec2 = DecoderBlock(bottom_channel_nr // 8 + num_filters * 2, num_filters * 2 * 2, num_filters * 2 * 2,
                                 is_deconv)
        self.dec1 = DecoderBlock(num_filters * 2 * 2, num_filters * 2 * 2, num_filters, is_deconv)
        self.dec0 = ConvRelu(num_filters, num_filters)
        self.final = nn.Conv2d(num_filters, num_classes, kernel_size=1)

    def forward(self, x):
        features = self.encoder.features(x)
        relued_features = self.encoder.relu(features)
        avg_pooled_features = self.encoder.avg_pool(relued_features)
        center = self.center(avg_pooled_features)
        # conv2..conv5 below stand for the intermediate encoder feature maps
        # used as skip connections; for PNASNet they still have to be extracted
        # from the encoder (see the sketch after this class).
        dec5 = self.dec5(torch.cat([center, conv5], 1))
        dec4 = self.dec4(torch.cat([dec5, conv4], 1))
        dec3 = self.dec3(torch.cat([dec4, conv3], 1))
        dec2 = self.dec2(torch.cat([dec3, conv2], 1))
        dec1 = self.dec1(dec2)
        dec0 = self.dec0(dec1)
        return self.final(F.dropout2d(dec0, p=self.dropout_2d))
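If you also want proper skip connections from intermediate PNASNet feature maps (rather than only the final 4320-channel output), a minimal sketch using forward hooks is below. The cells chosen here (cell_3, cell_7) are placeholders, not a recommendation; inspect m.named_children() to pick the resolutions and channel counts your decoder expects.

m = PNASNet5Large()
captured = {}

def save_as(name):
    # store the output of a module under the given key
    def hook(module, inputs, output):
        captured[name] = output
    return hook

m.cell_3.register_forward_hook(save_as('cell_3'))    # placeholder choice
m.cell_7.register_forward_hook(save_as('cell_7'))    # placeholder choice
m.cell_11.register_forward_hook(save_as('cell_11'))  # last cell, 4320 channels

_ = m.features(torch.randn(1, 3, 331, 331))
for name, t in captured.items():
    print(name, t.size())  # these sizes give you the decoder skip channel counts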
3) PNASNet5Large indeed doesn't have a conv1 attribute. You can check it with:

hasattr(m, 'conv1')
# False
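To see what submodules the model does expose, you can list its top-level children (just a quick inspection; the exact names depend on the PNASNet implementation you linked):

m = PNASNet5Large()
print([name for name, _ in m.named_children()])
# expect names along the lines of conv_0, cell_stem_0, cell_stem_1,
# cell_0 ... cell_11, relu, avg_pool, dropout, last_linear
# rather than conv1 / layer1..layer4 as in torchvision's ResNets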
Upvotes: 1