Engr Ali

Reputation: 359

Is it possible to retrain a trained model on fewer classes?

I am working on an image classification task in which each image is assigned to one or more of 14 different thoracic diseases (a multi-label classification problem). The model is trained on the NIH dataset, on which I get 80% AUC. Now I want to improve the model by training on a second dataset. The main problem is that the classes of the two datasets do not match.

The second dataset contains 10 classes, all of which overlap with classes in the first dataset that I trained the model on.

Questions:

  1. Is it possible to retrain a model on fewer classes?

  2. Will retraining my model on a new dataset impact the AUC of other non-similar classes?

  3. How big is the chance that this will improve the model?

The model and code are based on fast.ai and PyTorch.

Upvotes: 3

Views: 1030

Answers (1)

Kroshtan

Reputation: 677

Based on discussion in the comments:

  1. Yes. If the classes overlap (with different data points from a different dataset), you can train the same classifier layer on both datasets. In the second dataset, 4 of the 14 classes simply receive no training. In effect, you are making your existing 14-class dataset more imbalanced by adding samples for only 10 of the 14 classes.
  2. Training on only 10 of the 14 classes will introduce a forgetting effect on the 4 classes that get no additional training. You can counteract this somewhat by using the suggested alternating training, or by combining all the data into one big dataset, but neither changes the fact that the combined dataset is probably more imbalanced than the original 14-class dataset. The exception would be if the 4 classes missing from the 10-class dataset happen to be overrepresented in the 14-class dataset, but I assume you're not going to get that lucky.
  3. Because both your dataset and your model will focus more heavily on 10 of the 14 classes, your accuracy may go up. However, that gain comes from the 4 non-overlapping classes being neglected in favor of higher accuracy on the other 10. On paper the numbers may look better, but in practice you're making your model less useful for a 14-class classification task.
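
To make point 1 concrete, here is a minimal PyTorch sketch of one way to train a 14-output head on a batch that only has labels for 10 classes: compute the per-class binary cross-entropy and mask out the unlabeled classes. This is an illustration, not the asker's code. The function name `masked_bce_with_logits`, the index list `SHARED_IDX`, and the random tensors standing in for model outputs and labels are all assumptions; the real indices depend on how the two datasets' class names map onto each other.

```python
import torch
import torch.nn.functional as F

NUM_CLASSES = 14
# Hypothetical: positions (in the 14-class label vector) of the 10 classes
# that the second dataset also labels. Replace with the real mapping.
SHARED_IDX = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


def masked_bce_with_logits(logits, targets, class_mask):
    """Multi-label BCE averaged over labeled classes only.

    logits, targets: float tensors of shape (batch, NUM_CLASSES)
    class_mask:      float tensor of shape (NUM_CLASSES,), 1.0 = labeled
    """
    per_elem = F.binary_cross_entropy_with_logits(
        logits, targets, reduction="none"
    )
    mask = class_mask.expand_as(per_elem)
    # Unlabeled classes contribute zero loss, hence zero gradient.
    return (per_elem * mask).sum() / mask.sum().clamp(min=1.0)


# Mask for batches drawn from the 10-class dataset.
class_mask = torch.zeros(NUM_CLASSES)
class_mask[SHARED_IDX] = 1.0

# Random stand-ins for the model output and the labels (assumption).
torch.manual_seed(0)
logits = torch.randn(8, NUM_CLASSES, requires_grad=True)
targets = torch.randint(0, 2, (8, NUM_CLASSES)).float()

loss = masked_bce_with_logits(logits, targets, class_mask)
loss.backward()
```

For batches from the original 14-class dataset you would pass a mask of all ones. Note that the mask only removes the direct gradient on the 4 unlabeled outputs. The shared layers underneath are still updated by the other 10 classes, so the forgetting effect from point 2 can still show up on those 4 classes.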

Upvotes: 2
