Reputation: 21
I am training a model for a multi-output classification task: for each output (class) I have multiple possible labels. After running the test I got 100% accuracy for both outputs. I used 150,000 images for training and validation and 30,000 images for testing, with a pretrained mobilenet_v2 model and the CrossEntropy loss function. This is the calculate_metrics function I use to compute accuracy:
def calculate_metrics(output, target):
    """Return the per-head accuracy for one batch.

    Args:
        output: dict with keys 'action_name' and 'condition' holding the
            model's scores; the class dimension is taken to be dim 1
            (presumably ``(batch, n_classes)`` logits — confirm against model).
        target: dict with the same keys holding the ground-truth class
            indices for the batch.

    Returns:
        ``(accuracy_action, accuracy_condition)`` as Python floats: the
        fraction of samples whose highest-scoring class equals the target.
    """
    def _accuracy(key):
        # argmax over dim 1 returns the first maximal index on ties, matching
        # the indices returned by ``tensor.max(1)``.
        predicted = output[key].detach().argmax(dim=1).cpu()
        ground_truth = target[key].detach().cpu()
        return (predicted == ground_truth).float().mean().item()

    # No warning suppression: plain accuracy builds no confusion matrix, so
    # silencing warnings here would only hide genuine problems.
    return _accuracy('action_name'), _accuracy('condition')
And this is the training and validation script:
# NOTE: len(DataLoader) is the number of *batches*, not images, so the epoch
# loss/accuracy printed below are means of the per-batch values.
n_train_samples = len(train_dataloader)
print("Starting training ...")
for epoch in range(start_epoch, N_epochs + 1):
    # Set train mode explicitly every epoch instead of relying on validate()
    # (or model construction) having left the model in train mode.
    model.train()
    total_loss = 0
    accuracy_action = 0
    accuracy_condition = 0
    for batch in train_dataloader:
        optimizer.zero_grad()
        img = batch['img']
        # Move every label tensor of the batch to the training device.
        target_labels = {
            name: labels.to(device) for name, labels in batch['labels'].items()
        }
        output = model(img.to(device))
        loss_train, losses_train = model.get_loss(output, target_labels)
        total_loss += loss_train.item()
        batch_accuracy_action, batch_accuracy_condition = calculate_metrics(
            output, target_labels
        )
        accuracy_action += batch_accuracy_action
        accuracy_condition += batch_accuracy_condition
        loss_train.backward()
        optimizer.step()
    print("epoch {:4d}, loss: {:.4f}, action: {:.4f}, condition: {:.4f}".format(
        epoch,
        total_loss / n_train_samples,
        accuracy_action / n_train_samples,
        accuracy_condition / n_train_samples))
    logger.add_scalar('train_loss', total_loss / n_train_samples, epoch)
    # Validate and checkpoint every 5 epochs.
    if epoch % 5 == 0:
        validate(model, val_dataloader, logger, epoch, device)
        checkpoint_save(model, savedir, epoch)
And this is the validation function:
def validate(model, dataloader, logger, iteration, device, checkpoint=None):
    """Evaluate *model* on *dataloader*, print and log loss and accuracies.

    Args:
        model: network exposing ``get_loss(output, targets)``.
        dataloader: iterable of batches shaped like the dataset's
            ``{'img': ..., 'labels': {...}}`` samples.
        logger: object with ``add_scalar(tag, value, step)``.
        iteration: step value recorded with the logged scalars.
        device: device the images and labels are moved to.
        checkpoint: optional checkpoint passed to ``checkpoint_load`` first.

    Returns:
        ``(avg_loss, accuracy_action, accuracy_condition)``. The loss is the
        mean of per-batch losses; accuracies are weighted by batch size so a
        smaller final batch does not skew them.

    Raises:
        ValueError: if *dataloader* yields no batches.
    """
    if checkpoint is not None:
        checkpoint_load(model, checkpoint)
    model.eval()
    try:
        with torch.no_grad():
            total_loss = 0.0
            correct_action = 0.0
            correct_condition = 0.0
            n_batches = 0
            n_images = 0
            for batch in dataloader:
                img = batch['img']
                target_labels = {
                    name: labels.to(device)
                    for name, labels in batch['labels'].items()
                }
                output = model(img.to(device))
                val_loss, _ = model.get_loss(output, target_labels)
                total_loss += val_loss.item()
                batch_accuracy_action, batch_accuracy_condition = \
                    calculate_metrics(output, target_labels)
                # calculate_metrics returns per-batch fractions; convert them
                # back to (fractional) correct counts so batches of different
                # sizes are weighted correctly.
                batch_size = img.size(0)
                correct_action += batch_accuracy_action * batch_size
                correct_condition += batch_accuracy_condition * batch_size
                n_batches += 1
                n_images += batch_size
        if n_batches == 0 or n_images == 0:
            raise ValueError("validation dataloader yielded no samples")
        avg_loss = total_loss / n_batches
        accuracy_action = correct_action / n_images
        accuracy_condition = correct_condition / n_images
        print('-' * 72)
        print("Validation loss: {:.4f}, action: {:.4f}, condition: {:.4f}\n".format(
            avg_loss, accuracy_action, accuracy_condition))
        logger.add_scalar('val_loss', avg_loss, iteration)
        logger.add_scalar('val_accuracy_action', accuracy_action, iteration)
        logger.add_scalar('val_accuracy_condition', accuracy_condition, iteration)
    finally:
        # Always hand the model back in train mode, even if evaluation failed.
        model.train()
    return avg_loss, accuracy_action, accuracy_condition
The CSV files have this structure:
image_path,action_name,condition
D:\organized_files\half_data\training\Patient747_image142.jpg,EstablishAccountBalance,Hhealthy
D:\organized_files\half_data\training\Patient745_image1485.jpg,EstablishAccountBalance,Healthy
and this is the __getitem__ function:
def __getitem__(self, idx):
    """Return the sample at *idx* as ``{'img': image, 'labels': {...}}``.

    The image is read from ``self.data[idx]``, converted to RGB, and passed
    through ``self.transform`` when one is set. The labels are taken from
    ``self.action_name_labels`` and ``self.condition_labels`` at the same
    index.
    """
    img_path = self.data[idx]
    # Image.open is lazy and keeps the file open; convert() inside the `with`
    # forces the pixel data to load so the handle can be closed immediately.
    # RGB conversion guarantees 3 channels for the pretrained mobilenet_v2
    # (a no-op for images that are already RGB).
    with Image.open(img_path) as raw_img:
        img = raw_img.convert('RGB')
    if self.transform:
        img = self.transform(img)
    return {
        'img': img,
        'labels': {
            'action_name': self.action_name_labels[idx],
            'condition': self.condition_labels[idx],
        },
    }
I am confused by the 100% accuracy for both outputs. I don't think there is anything wrong with the script. Is there any explanation?
Upvotes: 1
Views: 69