Reputation: 519
I've got a fairly straight forward problem here.
I've just finished re-configuring a network by replacing nn.Upsample
with the upConv
sequential container shown in the code below. I've verified that everything is lined up by running summary(UNetPP, (3, 128, 128))
which runs with no issue.
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(0.0, 0.02)
elif classname.find('BatchNorm') != -1:
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
class blockUNetPP(nn.Module):
def __init__(self, in_channels, middle_channels, out_channels):
super().__init__()
self.relu = nn.LeakyReLU(0.2, inplace=True)
self.conv1 = nn.Conv2d(in_channels, middle_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(middle_channels)
self.conv2 = nn.Conv2d(middle_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
def forward(self, x):
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
return out
class upConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.upc = nn.Sequential(
nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
nn.Conv2d(in_ch, out_ch*2, 3, stride=1, padding=1),
nn.BatchNorm2d(out_ch*2),
nn.ReLU(inplace=True)
)
def forward(self, x):
out = self.upc(x)
return out
My issue is that when I try to start training the model I get the following issue:
Traceback (most recent call last):
File "runTrain.py", line 90, in <module>
netG.apply(weights_init)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 289, in apply
module.apply(fn)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 290, in apply
fn(self)
File "D:\Thesis Models\Deep_learning_models\UNet\train\NetC.py", line 8, in weights_init
m.weight.data.normal_(0.0, 0.02)
File "C:\Users\Anaconda3\envs\CFD\lib\site-packages\torch\nn\modules\module.py", line 594, in __getattr__
type(self).__name__, name))
AttributeError: 'upConv' object has no attribute 'weight'
I've looked up solutions which suggest looping over container modules, but I'm already doing this with weights_init(m)
. Could someone explain whats wrong with my current setup?
Upvotes: 0
Views: 11816
Reputation: 32972
You are deciding how to initialise the weight by checking that the class name includes Conv with classname.find('Conv')
. Your class has the name upConv, which includes Conv, therefore you try to initialise its attribute .weight
, but that doesn't exist.
Either rename your class or make the condition more strict, such as classname.find('Conv2d')
. The strictest approach would be to check whether it's an instance of nn.Conv2d
, instead of looking at the name of the class.
def weights_init(m):
if isinstance(m, nn.Conv2d):
m.weight.data.normal_(0.0, 0.02)
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.normal_(1.0, 0.02)
m.bias.data.fill_(0)
Upvotes: 1