lovepython
lovepython

Reputation: 1

Input tensor error when testing a U-Net on the four MRI modalities of the BraTS dataset

I'm working on the BraTS 2021 dataset with PyTorch and a U-Net model. I have trouble with my eval_fn function: I can't start testing on the validation dataset because of a mismatch in the dimensions of the input tensors. During training, each batch from the train loader contains two tensors. The first is the image, with 4 channels for the MRI modalities T1, T1ce, T2 and FLAIR. The second is a 1-channel segmentation mask. The test loader contains only the image, without any segmentation mask, and I get an error.

My training function, evaluation function and training loop:

def _split_model_output(buffer, masks, criterion):
    """Return ``(logits, loss)`` from the output of ``model(images, masks)``.

    The model may return either a ``(logits, loss)`` pair or bare logits; in
    the latter case the loss is computed as ``criterion(logits, masks)``.
    """
    # Test the container type, not len(): len() of a bare logits tensor is its
    # batch size, so any batch of >1 images would be wrongly unpacked along
    # dim 0 (crash, or silent garbage when the batch size is exactly 2).
    if isinstance(buffer, (tuple, list)):
        logits, loss = buffer
    else:
        logits = buffer
        loss = criterion(logits, masks)
    return logits, loss


def _dice_scores(output, masks):
    """Return per-sample Dice coefficients of binarised ``output`` vs ``masks``."""
    scores = []
    for pred, gt in zip(output.cpu().numpy(), masks.cpu().numpy()):
        pred = pred.astype(np.uint8)
        gt = gt.astype(np.uint8)
        # Dice is undefined when both prediction and ground truth are empty,
        # so those samples are skipped. Non-overlapping samples are kept:
        # dropping them (as before) excluded the worst predictions and
        # inflated the average.
        if not pred.any() and not gt.any():
            continue
        scores.append(dc(pred, gt))
    return scores


def train_fn(dataloader, model, optimizer):
    """Run one training epoch and print segmentation metrics.

    Args:
        dataloader: iterable yielding ``(images, masks)`` batches; the length
            of ``dataloader.dataset`` is used to average the losses.
        model: called as ``model(images, masks)``; may return bare logits or
            a ``(logits, loss)`` pair.
        optimizer: optimizer updating ``model``'s parameters.

    Returns:
        ``(avg_loss, avg_dice_loss)``: the optimised loss and the
        monitoring-only Dice loss, each averaged per sample over the dataset.

    Raises:
        ValueError: if ``dataloader.dataset`` is empty.
    """
    n_samples = len(dataloader.dataset)
    if n_samples == 0:
        raise ValueError("train_fn received an empty dataset")

    model.train()
    # Build the loss modules once rather than once per batch.
    bce_criterion = nn.BCEWithLogitsLoss()
    dice_criterion = DiceLoss(mode="binary")

    total_loss = 0.0
    total_dice_loss = 0.0
    tp_list, fp_list, fn_list, tn_list = [], [], [], []
    dsc_list = []

    for images, masks in tqdm(dataloader):
        images, masks = images.to(DEVICE), masks.to(DEVICE)
        logits, loss = _split_model_output(
            model(images, masks), masks, bce_criterion
        )

        # The Dice loss is only reported, never back-propagated, so compute it
        # on detached logits to keep it out of the autograd graph.
        dice_loss = dice_criterion(logits.detach(), masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        total_dice_loss += dice_loss.item() * batch_size

        # BCEWithLogitsLoss (and smp's DiceLoss, from_logits=True by default)
        # treat the model output as raw logits, so the 0.5 decision threshold
        # applies to sigmoid(logits), not to the logits themselves.
        output = (torch.sigmoid(logits.detach()) > 0.5).float()

        batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats(
            output.long(), masks.long(), mode="binary", threshold=0.5
        )
        tp_list.append(batch_tp)
        fp_list.append(batch_fp)
        fn_list.append(batch_fn)
        tn_list.append(batch_tn)

        dsc_list.extend(_dice_scores(output, masks))

    avg_loss = total_loss / n_samples
    avg_dice_loss = total_dice_loss / n_samples
    avg_dsc = sum(dsc_list) / len(dsc_list) if dsc_list else float("nan")

    tp = torch.cat(tp_list, dim=0)
    fp = torch.cat(fp_list, dim=0)
    fn = torch.cat(fn_list, dim=0)
    tn = torch.cat(tn_list, dim=0)

    # NOTE(review): the reductions differ between metrics (micro / macro /
    # micro-imagewise); "sensitivity" and "recall" are the same quantity
    # reported with different reductions. Kept as in the original.
    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
    sensitivity = smp.metrics.sensitivity(tp, fp, fn, tn, reduction="macro")
    recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
    precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")

    print(f"Average Loss: {avg_loss}")
    print(f"Average Dice Loss: {avg_dice_loss}")
    print(f"Average DSC: {avg_dsc}")

    print(f"iou_score: {iou_score}")
    print(f"f1_score: {f1_score}")
    print(f"sensitivity: {sensitivity}")
    print(f"recall: {recall}")
    print(f"precision: {precision}")

    return avg_loss, avg_dice_loss
def eval_fn(dataloader, model):
    """Run inference over ``dataloader`` and print the mean predicted probability.

    The test loader may yield bare image tensors, or ``(images, ...)``
    tuples/lists (e.g. a default-collated ``(image,)`` sample, or images with
    masks); only the first element is fed to the model. If the model returns
    a tuple/list, its first element is taken as the logits.

    Args:
        dataloader: iterable of image batches (see above).
        model: called as ``model(images)``.

    Returns:
        ``(avg_prediction, 0.0)``: the mean of ``sigmoid(logits)`` over every
        predicted element, plus a constant 0.0 placeholder. No ground truth is
        used, so neither value is a loss. ``avg_prediction`` is NaN when the
        loader yields no batches.
    """
    model.eval()
    # Keep a running sum instead of storing every prediction: concatenating
    # full volumes for a whole split is memory-hungry and fails when batches
    # differ in spatial size.
    prob_sum = 0.0
    prob_count = 0

    with torch.inference_mode():
        # desc is Russian for "Evaluation...".
        for batch in tqdm(dataloader, desc="Оценка..."):
            images = batch[0] if isinstance(batch, (tuple, list)) else batch
            images = images.to(DEVICE)

            logits = model(images)
            if isinstance(logits, (tuple, list)):
                logits = logits[0]

            pred = torch.sigmoid(logits)
            # Accumulate in float64 to limit rounding over many voxels.
            prob_sum += pred.double().sum().item()
            prob_count += pred.numel()

    avg_prediction = prob_sum / prob_count if prob_count else float("nan")

    print(f"\n: {avg_prediction:.4f}")

    return avg_prediction, 0.0
# Epoch driver: train, run inference on the test loader, record the returned
# values, and checkpoint the model whenever the training loss improves.
best_valid_loss = np.inf
train_losses = []
# NOTE(review): eval_fn returns a mean predicted probability in the slot
# stored here, not a loss.
test_losses = []
train_dice_losses = []
# NOTE(review): eval_fn always returns 0.0 in the slot stored here.
test_dice_losses = []

for i in range(EPOCHS):
    train_loss, train_dice_loss = train_fn(trainloader, model, optimizer)
    test_loss, test_dice_loss = eval_fn(testloader, model)

    train_losses.append(train_loss)
    test_losses.append(test_loss)
    train_dice_losses.append(train_dice_loss)
    test_dice_losses.append(test_dice_loss)

    # NOTE(review): despite its name, best_valid_loss tracks the *training*
    # loss, so "best_model.pt" is the checkpoint with the lowest training
    # loss. Selecting on a validation loss would need a labelled validation
    # loader (eval_fn computes no loss).
    if train_loss < best_valid_loss:
        torch.save(model.state_dict(), "best_model.pt")
        print("saved model")
        best_valid_loss = train_loss

    print(
        f"epochs: {i + 1} , train loss: {train_loss:.4f}")

(Three screenshots were attached here in the original post; the images are not available.)

I think the error is in the dataset loaders.

Upvotes: 0

Views: 15

Answers (0)

Related Questions