

PyTorch - Train imbalanced dataset (set weights) for object detection

I am quite new to PyTorch, and I am trying to use transfer learning with a pre-trained object detection model so that it learns to detect the objects in my new dataset.

Here is how I load the dataset:

```python
from torch.utils.data import DataLoader
# MyDataset, the transform helpers and collate_fn are defined elsewhere (not shown)
train_dataset = MyDataset(train_data_path, 512, 512, train_labels_path, get_train_transform())
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4, collate_fn=collate_fn)
valid_dataset = MyDataset(test_data_path, 512, 512, test_labels_path, get_valid_transform())
valid_loader = DataLoader(valid_dataset, batch_size=8, shuffle=False, num_workers=4, collate_fn=collate_fn)
```
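The `collate_fn` passed to the loaders is not shown. In detection, the targets contain a different number of boxes per image, so the default collation (which tries to stack everything into tensors) does not work for them. A common sketch, assuming `MyDataset.__getitem__` returns an `(image, target)` pair, simply regroups the batch into a tuple of images and a tuple of targets:

```python
def collate_fn(batch):
    # batch is a list of (image, target) pairs; regroup it into
    # (images, targets) without stacking, so each target keeps its own box count
    return tuple(zip(*batch))


# toy batch: two (image, target) pairs with different numbers of boxes
batch = [("img0", {"boxes": [[0, 0, 10, 10]]}), ("img1", {"boxes": []})]
images, targets = collate_fn(batch)
print(images)  # ('img0', 'img1')
```

This is why the training loop can do `images, targets = data` and then iterate over the images and targets individually.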

I define the model and optimizer as follows:

```python
import torch
import torchvision
from torchvision.models.detection import FasterRCNN_ResNet50_FPN_Weights
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# load a Faster R-CNN model pre-trained on COCO
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
# get the number of input features of the box classifier
in_features = model.roi_heads.box_predictor.cls_score.in_features
# define a new head for the detector with the required number of classes
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
model = model.to(device)
# get the trainable model parameters
params = [p for p in model.parameters() if p.requires_grad]
# SGD optimizer with a learning rate of 0.001, momentum of 0.9 and weight decay of 0.0005
optimizer = torch.optim.SGD(params, lr=0.001, momentum=0.9, weight_decay=0.0005)
```

I train the model as follows:

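```python
from tqdm import tqdm


def train(train_data_loader, model, optimizer, train_loss_hist):
    global train_itr
    global train_loss_list

    # torchvision detection models only return losses in training mode
    model.train()
    # use whatever device the model lives on
    device = next(model.parameters()).device

    prog_bar = tqdm(train_data_loader, total=len(train_data_loader), position=0, leave=True, ascii=True)

    # iterate over the batches
    for i, data in enumerate(prog_bar):
        images, targets = data

        # move the images and every tensor in the targets to the model's device
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # forward pass: in training mode the model returns a dict of losses
        loss_dict = model(images, targets)

        # sum the losses and append this iteration's loss value to train_loss_list
        losses = sum(loss for loss in loss_dict.values())
        loss_value = losses.item()
        train_loss_list.append(loss_value)
        # also send the loss value to train_loss_hist (an Averager defined elsewhere, assumed to have send())
        train_loss_hist.send(loss_value)

        # backpropagate the gradients and update the parameters
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        train_itr += 1
    return train_loss_list
```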
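The `Averager` class behind `train_loss_hist` is not shown in the question. A minimal sketch of such a running-average helper could look like this; the class body and the `send`/`value`/`reset` names are assumptions, not the original code.

```python
class Averager:
    """Hypothetical helper that keeps a running average of the loss values sent to it."""

    def __init__(self):
        self.current_total = 0.0
        self.iterations = 0

    def send(self, value):
        # accumulate one loss value
        self.current_total += value
        self.iterations += 1

    @property
    def value(self):
        # average of all values sent since the last reset (0.0 if none were sent)
        if self.iterations == 0:
            return 0.0
        return self.current_total / self.iterations

    def reset(self):
        # start a fresh average, e.g. at the beginning of each epoch
        self.current_total = 0.0
        self.iterations = 0


train_loss_hist = Averager()
for loss in [0.9, 0.7, 0.5]:
    train_loss_hist.send(loss)
print(f"average loss: {train_loss_hist.value:.2f}")  # average loss: 0.70
```

Calling `reset()` once per epoch makes `value` report the mean loss of that epoch rather than of the whole run.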

Since I adapted this code from an example I found, I am not sure where the loss is defined. I have not defined any loss myself, so I assume the model uses the default loss that the original detector was trained with. How can I update my code to train the network on such an imbalanced dataset?


Answers (1)



It seems that you have two questions.

  1. How to deal with an imbalanced dataset. Faster R-CNN is an anchor-based detector: the number of anchors that contain an object is tiny compared with the total number of anchors, and the model already copes with this foreground/background imbalance by sampling a fixed mix of positive and negative anchors (in the RPN) and proposals (in the box head) during training. So you don't need to handle the imbalance yourself. Alternatively, you can use RetinaNet, which introduced a loss function called focal loss to improve performance on imbalanced data.
  2. Where the loss function is. torchvision computes the losses inside the model object: in training mode, `model(images, targets)` returns a dict of losses, which is why you never had to define one. You can step through your code with a debugger into the torchvision package to see the implementation details.
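To illustrate the focal loss mentioned above, here is a minimal binary (sigmoid) version in plain PyTorch, following the RetinaNet formula FL(p_t) = -α_t (1 - p_t)^γ log(p_t). The function name `focal_loss` is hypothetical; torchvision ships its own implementation as `torchvision.ops.sigmoid_focal_loss`, which its RetinaNet model uses. The `(1 - p_t)^γ` factor shrinks the loss of examples that are already classified confidently, so the many easy background examples contribute little compared with the rare hard ones.

```python
import torch
import torch.nn.functional as F


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Sigmoid focal loss as in the RetinaNet paper, summed over all elements."""
    # per-element binary cross-entropy, i.e. -log(p_t)
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t is the predicted probability of the true class
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t weights positives by alpha and negatives by (1 - alpha)
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t)^gamma down-weights easy, well-classified examples
    return (alpha_t * (1 - p_t) ** gamma * ce).sum()


# an easy negative (confidently background) and a hard positive (misclassified)
logits = torch.tensor([-4.0, -1.0])
targets = torch.tensor([0.0, 1.0])
ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
for i, name in enumerate(["easy negative", "hard positive"]):
    fl = focal_loss(logits[i : i + 1], targets[i : i + 1])
    # fraction of the plain cross-entropy that survives the focal weighting
    print(name, fl.item() / ce[i].item())
```

With `gamma=0` and `alpha=0.5` the function reduces to half the ordinary binary cross-entropy, which is a quick way to sanity-check it.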
