PyTorch SSDLite: target class index out of range and torch.nn.functional.cross_entropy error

I am trying to fine-tune the SSDLite320_Mobilenet_V3_Large model on a custom dataset with 3 classes using PyTorch's reference code for vision/detection models. However, I suspect there is a problem in how the targets are created, either inside the torch library or in the code from the sub-repository (link above).

Indeed, when execution reaches the cross_entropy function in torch.nn's functional.py, I get the following error:

CUDA error: device-side assert triggered CUDA kernel errors.

As ptrblck mentions in a PyTorch forum thread on this issue, and as the PyTorch documentation states, the function accepts targets in the range [0, C), where C is the number of classes.
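To make the [0, C) condition concrete, here is a minimal pure-Python sketch of the check. The helper name `find_out_of_range` and the sample values are hypothetical and are not part of PyTorch; the point is only that, with C = 3, the value 3 is already out of range.

```python
def find_out_of_range(labels, num_classes):
    """Return the labels that fall outside the half-open range [0, num_classes)."""
    return [label for label in labels if not 0 <= label < num_classes]


# Illustrative values: with C = 3 only 0, 1 and 2 are valid targets.
sample_targets = [0, 1, 2, 3]
bad = find_out_of_range(sample_targets, num_classes=3)
print(bad)  # [3]
```

Running such a check on the CPU, before the loss is computed, avoids having to decode the asynchronous device-side assert.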

I therefore tried two things:

  1. I set ignore_index to 3 to ensure that no out-of-range value would be passed on. The code ran without error. However, looking at the predicted labels and the confusion matrix after training (see below; don't worry about the model's accuracy, it's only a debug dataset with 50 epochs), I saw that the model was unable to predict class number 3. Setting ignore_index to 3 doesn't seem to be a solution if it leads to the model's predictions missing a class.

Normalized Confusion Matrix showing the lack of "3" predictions.

  2. I checked the min and max values of my targets directly in the cross_entropy function, right before they were passed to return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing). I got 0 as the min and 3 as the max on every pass. Since my class ids are [1,2,3], I suspect the problem originates from this widening of the range. However, not being familiar with the inner workings of the torch library, I don't understand where this 0 comes from (or why the 3s aren't shifted down to 2, if any shifting takes place), or how the target object is created.
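The effect described in the first item can be reproduced with a small pure-Python sketch of mean cross-entropy with an ignore_index. This is an illustration only, not torch's implementation; the function and the sample logits are hypothetical. It shows that rows whose target equals ignore_index contribute nothing to the loss, so a class that is always ignored never receives a training signal.

```python
import math


def cross_entropy(logits, targets, ignore_index=-100):
    """Mean cross-entropy over the rows whose target is not ignore_index."""
    losses = []
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # ignored rows add no loss, hence no gradient
        if not 0 <= target < len(row):
            raise IndexError(f"target {target} outside [0, {len(row)})")
        # Numerically stable log-sum-exp of the row.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[target])
    # If every row was ignored there is nothing left to average.
    return sum(losses) / len(losses) if losses else float("nan")


# Three logits per row (C = 3), but one target is 3.
logits = [[2.0, 0.5, 0.1], [0.2, 1.5, 0.3], [0.1, 0.2, 3.0]]
targets = [0, 1, 3]

loss_ignoring_3 = cross_entropy(logits, targets, ignore_index=3)
loss_first_two = cross_entropy(logits[:2], targets[:2])
# The row labelled 3 is dropped entirely, so both losses are identical.
print(abs(loss_ignoring_3 - loss_first_two) < 1e-12)
```

Without ignore_index=3, the same call raises in this sketch because 3 is not a valid index into a row of three logits.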

As this problem might arise from the creation of the objects passed to the torch library's functions, here is some context on how they are created.

My model object is as follows:

model = torchvision.models.detection.ssdlite320_mobilenet_v3_large(num_classes=num_classes, pretrained_backbone=True, trainable_backbone_layers=0)
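One thing that may be worth checking here, offered as an assumption rather than a confirmed diagnosis: as far as the torchvision documentation describes it, num_classes for the detection models counts the background class, with index 0 reserved for background. If that applies, class ids [1,2,3] would need a classification head of size 4, not 3. The helper below is hypothetical and only does the arithmetic:

```python
def min_num_classes(label_ids):
    """Smallest head size C such that every label lies in [0, C)."""
    return max(label_ids) + 1


class_ids = [1, 2, 3]   # foreground ids from the question
background_id = 0       # assumption: index 0 is reserved for background
needed = min_num_classes(class_ids + [background_id])
print(needed)  # 4
```

Under that assumption, the value returned by get_dataset as num_classes would have to be 4 for the targets observed above (min 0, max 3) to stay in range.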

The part of my code that creates the data loaders is the following:

# Data loading code
print("Loading data")

dataset, num_classes = get_dataset(is_train=True, args=args)
dataset_test, _ = get_dataset(is_train=False, args=args)
dataset_val, _ = get_dataset(is_train=False, args=args)

print("Creating data loaders")
if args.distributed:
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
    test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
    val_sampler = torch.utils.data.distributed.DistributedSampler(dataset_val, shuffle=False)
else:
    train_sampler = torch.utils.data.RandomSampler(dataset)
    test_sampler = torch.utils.data.SequentialSampler(dataset_test)
    val_sampler = torch.utils.data.SequentialSampler(dataset_val)

if args.aspect_ratio_group_factor >= 0:
    group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
    train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
else:
    train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)

train_collate_fn = utils.collate_fn
if args.use_copypaste:
    if args.data_augmentation != "lsj":
        raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policies")

    train_collate_fn = copypaste_collate_fn

data_loader = torch.utils.data.DataLoader(
    dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn
)

data_loader_test = torch.utils.data.DataLoader(
    dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)

data_loader_val = torch.utils.data.DataLoader(
    dataset_val, batch_size=1, sampler=val_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
)
