Reputation: 1
I'm working on the BraTS 2021 dataset with a U-Net model in PyTorch. I'm having trouble with my `eval_fn` function: I can't start testing on the validation dataset because of errors in the dimensions of the input tensors. For training, each batch from `trainloader` has two tensors — an image with 4 channels (the MRI modalities T1, T1ce, T2 and FLAIR) and a 1-channel segmentation mask — but `testloader` provides only the images for testing, without any segmentation masks, and I get an error.
My training function:
def train_fn(dataloader, model, optimizer):
    """Train ``model`` for one epoch over ``dataloader`` and print metrics.

    Args:
        dataloader: DataLoader yielding ``(images, masks)`` batches.
        model: network called as ``model(images, masks)``. It may return either
            a ``(logits, loss)`` pair or the logits alone; in the latter case the
            loss is ``BCEWithLogitsLoss`` on the logits.
        optimizer: optimizer that steps ``model``'s parameters.

    Returns:
        ``(avg_loss, avg_dice_loss)``: the optimised loss and the binary Dice
        loss, each averaged per sample over ``dataloader.dataset``. The Dice
        loss is only monitored here; it is not back-propagated.
    """
    model.train()
    # Build the loss modules once instead of once per batch.
    bce_loss_fn = nn.BCEWithLogitsLoss()
    dice_loss_fn = DiceLoss(mode="binary")

    total_loss = 0.0
    total_dice_loss = 0.0
    tp_list, fp_list, fn_list, tn_list = [], [], [], []
    dsc_list = []

    for images, masks in tqdm(dataloader):
        images, masks = images.to(DEVICE), masks.to(DEVICE)

        buffer = model(images, masks)
        # A bare tensor also supports len() (it is the batch size), so a length
        # check cannot tell "(logits, loss)" from plain logits; test the type.
        if isinstance(buffer, (tuple, list)):
            logits, loss = buffer
        else:
            logits = buffer
            loss = bce_loss_fn(logits, masks)
        dice_loss = dice_loss_fn(logits, masks)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        total_loss += loss.item() * batch_size
        total_dice_loss += dice_loss.item() * batch_size

        # ``logits`` are raw scores (they feed BCEWithLogitsLoss), so binarise
        # at probability 0.5 via sigmoid; ``logits > 0.5`` would demand p > ~0.62.
        output = (torch.sigmoid(logits) > 0.5).float()
        batch_tp, batch_fp, batch_fn, batch_tn = smp.metrics.get_stats(
            output.long(), masks.long(), mode="binary", threshold=0.5
        )
        tp_list.append(batch_tp)
        fp_list.append(batch_fp)
        fn_list.append(batch_fn)
        tn_list.append(batch_tn)

        # Per-sample Dice coefficient, only for samples where prediction and
        # ground truth overlap (empty or disjoint pairs are skipped).
        for pred, gt in zip(output.cpu().numpy(), masks.cpu().numpy()):
            if not np.logical_and(pred, gt).any():
                continue
            dsc_list.append(dc(pred.astype(np.uint8), gt.astype(np.uint8)))

    avg_loss = total_loss / len(dataloader.dataset)
    avg_dice_loss = total_dice_loss / len(dataloader.dataset)
    avg_dsc = sum(dsc_list) / len(dsc_list) if dsc_list else float("nan")

    tp = torch.cat(tp_list, dim=0)
    fp = torch.cat(fp_list, dim=0)
    fn = torch.cat(fn_list, dim=0)
    tn = torch.cat(tn_list, dim=0)

    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
    # NOTE(review): sensitivity and recall are the same ratio TP / (TP + FN);
    # they differ here only in the reduction ("macro" vs "micro-imagewise").
    sensitivity = smp.metrics.sensitivity(tp, fp, fn, tn, reduction="macro")
    recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
    precision = smp.metrics.precision(tp, fp, fn, tn, reduction="micro")

    print(f"Average Loss: {avg_loss}")
    print(f"Average Dice Loss: {avg_dice_loss}")
    print(f"Average DSC (overlapping samples only): {avg_dsc}")
    print(f"iou_score: {iou_score}")
    print(f"f1_score: {f1_score}")
    print(f"sensitivity: {sensitivity}")
    print(f"recall: {recall}")
    print(f"precision: {precision}")
    return avg_loss, avg_dice_loss
def eval_fn(dataloader, model):
    """Run inference over ``dataloader`` and report the mean predicted probability.

    Accepts loaders that yield either a bare image tensor (e.g. an unlabelled
    test split) or an ``(images, masks, ...)`` tuple/list; only the first item
    (the images) is used, so labelled and unlabelled loaders both work.

    Args:
        dataloader: iterable of batches as described above.
        model: segmentation network called as ``model(images)``. If it returns
            a tuple/list, its first element is taken as the logits.

    Returns:
        ``(avg_prediction, 0.0)``. ``avg_prediction`` is the mean of
        ``sigmoid(logits)`` over every element of every batch (``nan`` when the
        loader yields nothing). The second value is a placeholder kept so
        callers can unpack two values; neither value is a loss.
    """
    model.eval()
    # Keep a running sum/count instead of storing every prediction on the host:
    # predictions can be large (e.g. full MRI volumes) and only the mean is used.
    total = 0.0
    count = 0
    with torch.inference_mode():
        # tqdm label "Оценка..." is Russian for "Evaluating...".
        for batch in tqdm(dataloader, desc="Оценка..."):
            # Labelled loaders collate samples into (images, masks); an
            # image-only loader yields the tensor itself.
            images = batch[0] if isinstance(batch, (tuple, list)) else batch
            images = images.to(DEVICE)
            output = model(images)
            logits = output[0] if isinstance(output, (tuple, list)) else output
            pred = torch.sigmoid(logits)
            # Accumulate in float64 to avoid float32 drift over many voxels.
            total += pred.double().sum().item()
            count += pred.numel()
    avg_prediction = total / count if count else float("nan")
    print(f"\nAverage prediction: {avg_prediction:.4f}")
    return avg_prediction, 0.0
# Epoch driver: train, run eval_fn on the test loader, record history, and
# checkpoint whenever the epoch's training loss improves.
best_valid_loss = np.inf
train_losses, test_losses = [], []
train_dice_losses, test_dice_losses = [], []

for i in range(EPOCHS):
    train_loss, train_dice_loss = train_fn(trainloader, model, optimizer)
    # NOTE(review): eval_fn in this file returns (mean sigmoid probability, 0.0),
    # so the "test" histories below do not hold losses.
    test_loss, test_dice_loss = eval_fn(testloader, model)

    train_losses.append(train_loss)
    train_dice_losses.append(train_dice_loss)
    test_losses.append(test_loss)
    test_dice_losses.append(test_dice_loss)

    # NOTE(review): despite its name, best_valid_loss tracks the best *training*
    # loss; checkpoint selection is therefore on training loss, not validation.
    if train_loss < best_valid_loss:
        torch.save(model.state_dict(), "best_model.pt")
        print("saved model")
        best_valid_loss = train_loss

    print(f"epochs: {i + 1} , train loss: {train_loss:.4f}")
[screenshot 1 not available]
[screenshot 2 not available]
[screenshot 3 not available]
I think the error is in the dataset loaders.
Upvotes: 0
Views: 15