Erlend

Reputation: 1

Add out_channel to pretrained model

I have trained a model, fish.pth, with 2 output channels (fish and background). I want to add a layer so that I can train a new model, spine.pth, which segments 3 classes: fish, spine and background. Yes, I could do it in one operation, but I want to learn the method. I made a class that initializes a model with 2 out_channels so that I can load the weights from my pretrained model. Do I have to make a new model, or can I just alter model_spine?

A code snippet with what I have tried is below. When I get to the training loop, outputs = model_spine(inputs) gives outputs.shape=[20,2,256,256], where I expected outputs.shape=[20,3,256,256].

    # 3) --------- Model for spine segmentation ---------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")         
        
    class SpineSegmentationModel2(nn.Module):
        def __init__(self):
            super(SpineSegmentationModel2, self).__init__()
            self.unet = UNet(
                spatial_dims=2,
                in_channels=1,  # 1-channel (grayscale input)
                out_channels=2,  # 2 classes (background, fish)
                channels=(16, 32, 64, 128, 256),
                strides=(2, 2, 2, 2),
                num_res_units=2,
            )
            # Load pretrained weights
            self.unet.load_state_dict(torch.load(pretrain_in), strict=False)
            
            # Adjust to 3 output channels
            self.unet.out_conv = nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) 

        def forward(self, x):
            return self.unet(x)   
    
    #3a)------ --------- ---------
    model_spine = SpineSegmentationModel2().to(device)  # runs load_state_dict internally
    model_spine.state_dict()
    model_spine.load_state_dict(torch.load(pretrain_in))

In case it helps, I have uploaded fish.pth, the script for training and evaluating the model, and 3 images here: https://github.com/ErlendAQK/Questions

I tried to alter just the last layer with model_spine.unet.model[-1][0].conv = nn.ConvTranspose2d(32, 3, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1)) after model_spine.load_state_dict(...), but it seems it must be altered at a "deeper" level; see the sketch below. I also tried to make a model with 3 out_channels, but then I could not load the weights from the pretrained model.
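To see what the last block actually contains, it can be printed. A minimal sketch, assuming the same constructor arguments as in my code:

    from monai.networks.nets import UNet

    # Build the same 2-class UNet as above and inspect its final decoder block.
    unet = UNet(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    # With num_res_units=2 the last block is a Sequential of a transposed
    # Convolution and a ResidualUnit; both have channel counts tied to
    # out_channels, so swapping only model[-1][0].conv leaves the trailing
    # ResidualUnit at 2 channels.
    print(unet.model[-1])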

Full code:

# -*- coding: utf-8 -*-

import torch
import monai
import torch.nn as nn
import torch.optim as optim
from monai.networks.nets import UNet
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.data import Dataset, DataLoader
from monai.transforms import Compose, LoadImaged, EnsureChannelFirstd, ScaleIntensityd, RandCropByPosNegLabeld, RandRotate90d, SqueezeDimd
from monai.inferers import sliding_window_inference

from torch.utils.tensorboard import SummaryWriter
from monai.visualize import plot_2d_or_3d_image

import os
import sys
import logging
import numpy as np
import matplotlib.pyplot as plt

# Input data
datadir = r'C:\Users\John\datasets\images'
segdir = r'C:\Users\John\datasets\images\labels'
Nepochs = 50
pretrain_in ='fish.pth'
model_out = 'spine.pth'

def main(datadir, segdir, Nepochs, pretrain_in, model_out):
    # ====================== START ======================
    monai.config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)
    
    # 1) --------- Read in images and labels ---------------------
    objekter = [fil.removesuffix('.nii.gz') for fil in os.listdir(segdir) if fil.endswith('.nii.gz')]
    bilder_inn = []
    for obj in objekter:
        path_img = os.path.join(datadir, obj + '.dcm.nii.gz')
        path_seg = os.path.join(segdir, obj + '.nii.gz')        
        if os.path.exists(path_img) and os.path.exists(path_seg):
            bilder_inn.append({'img': path_img, 'seg': path_seg})
        else:
            print(f"OOPS! Could not find the files for '{obj}' - skipping.")
    
    # Split dataset into training and validation
    split_index = int(len(bilder_inn) * 0.8)
    train_data_dicts = bilder_inn[:split_index]
    val_data_dicts = bilder_inn[split_index:]
    
    # 2) --------- Transforms ---------------------
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"], reader='PydicomReader'),
        EnsureChannelFirstd(keys=["img", "seg"]),
        SqueezeDimd(keys=["img", "seg"], dim=-1),
        ScaleIntensityd(keys=["img"]),
        RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg", spatial_size=[256, 256], pos=1, neg=1, num_samples=2),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 1]),
    ])
    
    val_transforms = Compose([
        LoadImaged(keys=["img", "seg"], reader='PydicomReader'),
        EnsureChannelFirstd(keys=["img", "seg"]),
        SqueezeDimd(keys=["img", "seg"], dim=-1),
        ScaleIntensityd(keys=["img"]),
    ])
    
    train_ds = Dataset(data=train_data_dicts, transform=train_transforms)
    val_ds = Dataset(data=val_data_dicts, transform=val_transforms)
    
    train_loader = DataLoader(train_ds, batch_size=1, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_ds, batch_size=1, num_workers=4)
    
    # 3) --------- Model for spine segmentation ---------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")         
        
    class SpineSegmentationModel2(nn.Module):
        def __init__(self):
            super(SpineSegmentationModel2, self).__init__()
            self.unet = UNet(
                spatial_dims=2,
                in_channels=1,  # 1-channel (grayscale input)
                out_channels=2,  # 2 classes (background, fish)
                channels=(16, 32, 64, 128, 256),
                strides=(2, 2, 2, 2),
                num_res_units=2,
            ) 
            # Load pretrained weights
            self.unet.load_state_dict(torch.load(pretrain_in), strict=False)
            
            # Adjust to 3 output channels
            self.unet.out_conv = nn.Conv2d(256, 3, kernel_size=(1, 1), stride=(1, 1)) 

        def forward(self, x):
            return self.unet(x)   
    
    #3a)------ --------- ---------
    model_spine = SpineSegmentationModel2().to(device)  # runs load_state_dict internally
    model_spine.state_dict()
    model_spine.load_state_dict(torch.load(pretrain_in))


    # 4) --------- Loss function, optimizer and metrics ---------------
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model_spine.parameters(), lr=1e-4)
    dice_metric = DiceMetric(include_background=False, reduction="mean")
    
    # 5) --------- Training loop ---------------
    val_interval = 2
    writer = SummaryWriter('runs_cascade')  # log_dir, e.g. 'runs/experiment_name'
    for epoch in range(Nepochs):
        print(f"Epoch {epoch+1}/{Nepochs}")
        
        model_spine.train()  # Set model to training mode
        epoch_loss = 0
        step = 0
        
        for batch_data in train_loader:
            step += 1
            inputs = batch_data["img"].to(device)
            labels = batch_data["seg"].to(device).long()  # convert to integers
            
            optimizer.zero_grad()
            outputs = model_spine(inputs)
            
            # Compute loss
            loss = loss_function(outputs, labels.squeeze(1))
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        
        epoch_loss /= step
        print(f"Training loss for epoch {epoch+1}: {epoch_loss}")
        writer.add_scalar('Loss/train', epoch_loss, epoch)  # epoch_loss is already averaged over steps
        
        # Validation
        if (epoch + 1) % val_interval == 0:
            model_spine.eval()
            with torch.no_grad():
                val_dice = 0
                val_steps = 0
                
                for val_data in val_loader:
                    val_steps += 1
                    val_inputs = val_data["img"].to(device)
                    val_labels = val_data["seg"].to(device)
                    val_outputs = sliding_window_inference(val_inputs, roi_size=(256, 256), sw_batch_size=1, predictor=model_spine)
                    dice_metric(y_pred=val_outputs, y=val_labels)
                
                mean_dice = dice_metric.aggregate().item()
                dice_metric.reset()
                print(f"Validation Dice score for epoch {epoch+1}: {mean_dice}")
    
    # Save the trained model
    torch.save(model_spine.state_dict(), model_out)


    


if __name__ == '__main__':
    main(datadir, segdir, Nepochs, pretrain_in, model_out)

Upvotes: 0

Views: 44

Answers (1)

Erlend

Reputation: 1

I think I solved the problem (with some help from GPT). Does the result look correct, or have I missed something?

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")           
   
class UNet3(nn.Module):
    def __init__(self):
        super(UNet3, self).__init__()
        self.unet = UNet(
            spatial_dims=2,
            in_channels=1,  # 1-channel (grayscale input)
            out_channels=3,  # 3 classes (background, fish, spine)
            channels=(16, 32, 64, 128, 256),
            strides=(2, 2, 2, 2),
            num_res_units=2,
        ) 

    def forward(self, x):
        return self.unet(x)

model_spine = UNet3().to(device)  # weights are loaded below, not in __init__
model_dict = model_spine.state_dict()
pretrained_dict = torch.load(pretrain_in)

# Keep only the pretrained tensors whose key and shape match the new model
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict and model_dict[k].shape == v.shape}

model_dict.update(pretrained_dict)
model_spine.load_state_dict(model_dict) 
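Note that strict=False alone would not be enough here: it only ignores missing or unexpected keys and still raises on shape mismatches, which is why the filter above is needed. For reuse, the same idea can be wrapped in a small helper; this is just a sketch, and the name load_matching_weights is my own, not from MONAI or PyTorch:

import torch

def load_matching_weights(model, checkpoint_path):
    """Copy only the checkpoint tensors whose key and shape match the model."""
    model_dict = model.state_dict()
    pretrained_dict = torch.load(checkpoint_path)
    matched = {k: v for k, v in pretrained_dict.items()
               if k in model_dict and model_dict[k].shape == v.shape}
    model_dict.update(matched)
    model.load_state_dict(model_dict)
    # Report what was skipped - here it should be the final block, whose
    # shapes changed when out_channels went from 2 to 3.
    skipped = sorted(set(pretrained_dict) - set(matched))
    print(f"Skipped {len(skipped)} mismatched or unknown keys: {skipped}")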

If I then do

for key, value in model_spine.state_dict().items():
    print(key, value.size()) 

I get the following result. Does it seem correct?

unet.model.0.conv.unit0.conv.weight torch.Size([16, 1, 3, 3])
unet.model.0.conv.unit0.conv.bias torch.Size([16])
unet.model.0.conv.unit0.adn.A.weight torch.Size([1])
unet.model.0.conv.unit1.conv.weight torch.Size([16, 16, 3, 3])
unet.model.0.conv.unit1.conv.bias torch.Size([16])
unet.model.0.conv.unit1.adn.A.weight torch.Size([1])
unet.model.0.residual.weight torch.Size([16, 1, 3, 3])
unet.model.0.residual.bias torch.Size([16])
unet.model.1.submodule.0.conv.unit0.conv.weight torch.Size([32, 16, 3, 3])
unet.model.1.submodule.0.conv.unit0.conv.bias torch.Size([32])
unet.model.1.submodule.0.conv.unit0.adn.A.weight torch.Size([1])
unet.model.1.submodule.0.conv.unit1.conv.weight torch.Size([32, 32, 3, 3])
unet.model.1.submodule.0.conv.unit1.conv.bias torch.Size([32])
unet.model.1.submodule.0.conv.unit1.adn.A.weight torch.Size([1])
unet.model.1.submodule.0.residual.weight torch.Size([32, 16, 3, 3])
unet.model.1.submodule.0.residual.bias torch.Size([32])
unet.model.1.submodule.1.submodule.0.conv.unit0.conv.weight torch.Size([64, 32, 3, 3])
unet.model.1.submodule.1.submodule.0.conv.unit0.conv.bias torch.Size([64])
unet.model.1.submodule.1.submodule.0.conv.unit0.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.0.conv.unit1.conv.weight torch.Size([64, 64, 3, 3])
unet.model.1.submodule.1.submodule.0.conv.unit1.conv.bias torch.Size([64])
unet.model.1.submodule.1.submodule.0.conv.unit1.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.0.residual.weight torch.Size([64, 32, 3, 3])
unet.model.1.submodule.1.submodule.0.residual.bias torch.Size([64])
unet.model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.weight torch.Size([128, 64, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.0.conv.unit0.conv.bias torch.Size([128])
unet.model.1.submodule.1.submodule.1.submodule.0.conv.unit0.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.1.submodule.0.conv.unit1.conv.weight torch.Size([128, 128, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.0.conv.unit1.conv.bias torch.Size([128])
unet.model.1.submodule.1.submodule.1.submodule.0.conv.unit1.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.1.submodule.0.residual.weight torch.Size([128, 64, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.0.residual.bias torch.Size([128])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.conv.weight torch.Size([256, 128, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.conv.bias torch.Size([256])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit0.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.conv.weight torch.Size([256, 256, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.conv.bias torch.Size([256])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.conv.unit1.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.residual.weight torch.Size([256, 128, 1, 1])
unet.model.1.submodule.1.submodule.1.submodule.1.submodule.residual.bias torch.Size([256])
unet.model.1.submodule.1.submodule.1.submodule.2.0.conv.weight torch.Size([384, 64, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.2.0.conv.bias torch.Size([64])
unet.model.1.submodule.1.submodule.1.submodule.2.0.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.conv.weight torch.Size([64, 64, 3, 3])
unet.model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.conv.bias torch.Size([64])
unet.model.1.submodule.1.submodule.1.submodule.2.1.conv.unit0.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.2.0.conv.weight torch.Size([128, 32, 3, 3])
unet.model.1.submodule.1.submodule.2.0.conv.bias torch.Size([32])
unet.model.1.submodule.1.submodule.2.0.adn.A.weight torch.Size([1])
unet.model.1.submodule.1.submodule.2.1.conv.unit0.conv.weight torch.Size([32, 32, 3, 3])
unet.model.1.submodule.1.submodule.2.1.conv.unit0.conv.bias torch.Size([32])
unet.model.1.submodule.1.submodule.2.1.conv.unit0.adn.A.weight torch.Size([1])
unet.model.1.submodule.2.0.conv.weight torch.Size([64, 16, 3, 3])
unet.model.1.submodule.2.0.conv.bias torch.Size([16])
unet.model.1.submodule.2.0.adn.A.weight torch.Size([1])
unet.model.1.submodule.2.1.conv.unit0.conv.weight torch.Size([16, 16, 3, 3])
unet.model.1.submodule.2.1.conv.unit0.conv.bias torch.Size([16])
unet.model.1.submodule.2.1.conv.unit0.adn.A.weight torch.Size([1])
unet.model.2.0.conv.weight torch.Size([32, 3, 3, 3])
unet.model.2.0.conv.bias torch.Size([3])
unet.model.2.0.adn.A.weight torch.Size([1])
unet.model.2.1.conv.unit0.conv.weight torch.Size([3, 3, 3, 3])
unet.model.2.1.conv.unit0.conv.bias torch.Size([3])

And in the training loop, outputs = model_spine(inputs) now gives outputs.shape=[batch_size, 3, 256, 256].
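As a quick sanity check, a dummy batch can be pushed through the model (a sketch, using random data rather than the real loader):

import torch

model_spine.eval()
with torch.no_grad():
    dummy = torch.randn(2, 1, 256, 256, device=device)  # batch of 2 grayscale 256x256 images
    out = model_spine(dummy)
print(out.shape)  # expected: torch.Size([2, 3, 256, 256])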

Upvotes: 0
