Neil

Reputation: 81

Reduce multiclass image classification to binary classification in Pytorch

I am working with the STL-10 image dataset, which has 10 classes. I want to reduce this multiclass classification problem to binary classification, for example class 1 vs. the rest. I am using PyTorch's torchvision to download and load the data, but I can't figure out how to set it up as one-vs-rest.

train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)

train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

Upvotes: 1

Views: 711

Answers (3)

For torchvision datasets, there is a built-in way to do this. Define a transformation function or class and pass it as the target_transform argument when creating the dataset.

torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)

Here is a working example for reference:


import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms


class Multi2UniLabelTfm():
    """Map the labels listed in pos_label to 1 and every other label to 0."""

    def __init__(self, pos_label=5):
        # Accept either a single label or a collection of labels.
        if isinstance(pos_label, (int, float)):
            pos_label = [pos_label]
        self.pos_label = pos_label

    def __call__(self, y):
        # y is the original class index of one sample.
        return 1 if y in self.pos_label else 0

if __name__=='__main__':

    test_tfms = transforms.Compose([
        transforms.ToTensor()
    ])
    data_transforms = {'val':test_tfms}


    #Original Labels
    # target_transform = None   

    # Label 5 is converted to 1. Rest are 0.
    # target_transform = Multi2UniLabelTfm(pos_label=5)     

    # Labels 5,6,7 are converted to 1. Rest are 0.
    target_transform = Multi2UniLabelTfm(pos_label=[5,6,7])
    test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True, target_transform=target_transform)
    test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)

    for idx,(x,y) in enumerate(test_dataloader):
        print(idx,y)

        if idx == 5:
            break

Upvotes: 1

asymptote

Reputation: 1402

One way is to update the label values at runtime, before passing them to the loss function in the training loop. Say we want to relabel class 5 as 1 and the rest as 0:

my_class_id = 5
for imgs, labels in train_dataloader:
    labels = torch.where(labels == my_class_id, 1, 0)
    ...

You may also need to apply the same relabeling for test_dataloader. Also, I am not sure about the datatype of labels. If it's float, change the code accordingly.
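On the datatype point, here is a minimal sketch of how the relabeled tensor could be cast for two common loss functions. It assumes the batch labels arrive as integer class indices, which is what the default collate function typically produces from integer targets. The label batch and the all-zero logits below are made-up stand-ins, not real STL-10 data or model outputs.

```python
import torch
import torch.nn as nn

my_class_id = 5

# Stand-in for one batch of labels (integer class indices in 0..9).
labels = torch.tensor([5, 0, 9, 5, 3])

# The comparison gives a boolean tensor; cast it to what the loss expects.
binary_long = (labels == my_class_id).long()    # class indices for nn.CrossEntropyLoss
binary_float = (labels == my_class_id).float()  # float targets for nn.BCEWithLogitsLoss

# Hypothetical model outputs, only to show the shape each loss accepts.
logits_two = torch.zeros(5, 2)  # two logits per sample, shape (batch, 2)
logits_one = torch.zeros(5)     # one logit per sample, shape (batch,)

ce_loss = nn.CrossEntropyLoss()(logits_two, binary_long)
bce_loss = nn.BCEWithLogitsLoss()(logits_one, binary_float)
```

Which cast applies depends on the model head: a two-output head with nn.CrossEntropyLoss wants integer targets, while a single-output head with nn.BCEWithLogitsLoss wants float targets of the same shape as the logits.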

Upvotes: 0

ki-ljl

Reputation: 509

You need to relabel the images. Initially, the ten classes carry labels 0 through 9: class 0 has label 0, class 1 has label 1, ..., and class 9 has label 9. For binary classification, change the label of the images in category 1 (or whichever class you choose) to 0, and the label of the images in all other categories to 1.
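A rough sketch of that relabeling as a thin wrapper around an existing dataset. The class name OneVsRestDataset and its parameters are made up for illustration, and the list of string/label pairs stands in for real (image, label) samples. The wrapper only assumes the wrapped object supports len() and integer indexing, which map-style datasets such as the torchvision ones do. Subclassing torch.utils.data.Dataset would be the usual convention and is left out here only to keep the sketch free of dependencies.

```python
class OneVsRestDataset:
    """Wrap a map-style dataset of (image, label) pairs and relabel on access."""

    def __init__(self, base, target_class, target_label=0, rest_label=1):
        self.base = base                  # the original dataset
        self.target_class = target_class  # original label of the chosen class
        self.target_label = target_label  # new label for the chosen class
        self.rest_label = rest_label      # new label for every other class

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        image, label = self.base[idx]
        if label == self.target_class:
            return image, self.target_label
        return image, self.rest_label


# Stand-in samples: (image placeholder, original label).
samples = [("img_a", 0), ("img_b", 1), ("img_c", 9), ("img_d", 1)]

# Category 1 becomes 0, everything else becomes 1, as described above.
binary = OneVsRestDataset(samples, target_class=1)
relabeled = [binary[i][1] for i in range(len(binary))]
```

The wrapped object could then be passed to a DataLoader in place of the original dataset, with the images themselves left untouched.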

Upvotes: 0
