Reputation: 81
I am working with the STL-10 image dataset, which consists of 10 different classes. I want to reduce this multiclass image classification problem to a binary one, such as class 1 vs. the rest. I am using PyTorch's torchvision to download and load the STL-10 data, but I can't figure out how to set it up as one vs. the rest.
train_data=torchvision.datasets.STL10(root='data',split='train',transform=data_transforms['train'], download=True)
test_data=torchvision.datasets.STL10(root='data',split='test',transform=data_transforms['val'], download=True)
train_dataloader = DataLoader(train_data,batch_size = 64,shuffle=True,num_workers=2)
test_dataloader = DataLoader(test_data,batch_size = 64,shuffle=True,num_workers=2)
Upvotes: 1
Views: 711
Reputation: 19
For torchvision datasets, there is a built-in way to do this. Define a transformation function or class and pass it as the target_transform argument when creating the dataset.
torchvision.datasets.STL10(root: str, split: str = 'train', folds: Union[int, NoneType] = None, transform: Union[Callable, NoneType] = None, target_transform: Union[Callable, NoneType] = None, download: bool = False)
Here is a working example for reference:
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

class Multi2UniLabelTfm():
    def __init__(self, pos_label=5):
        # Accept a single label or a list of labels
        if isinstance(pos_label, (int, float)):
            pos_label = [pos_label,]
        self.pos_label = pos_label

    def __call__(self, y):
        # if y==self.pos_label:
        if y in self.pos_label:
            return 1
        else:
            return 0

if __name__ == '__main__':
    test_tfms = transforms.Compose([
        transforms.ToTensor()
    ])
    data_transforms = {'val': test_tfms}

    # Original labels
    # target_transform = None

    # Label 5 is converted to 1. Rest are 0.
    # target_transform = Multi2UniLabelTfm(pos_label=5)

    # Labels 5, 6, 7 are converted to 1. Rest are 0.
    target_transform = Multi2UniLabelTfm(pos_label=[5, 6, 7])

    test_data = torchvision.datasets.STL10(root='data', split='test', transform=data_transforms['val'], download=True, target_transform=target_transform)
    test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True, num_workers=2)

    # Print the labels of the first few batches to check the conversion
    for idx, (x, y) in enumerate(test_dataloader):
        print(idx, y)
        if idx == 5:
            break
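Since target_transform is just a callable applied to each integer label, the mapping can be checked without downloading the dataset. Below is a minimal sketch of a function-based alternative. The names one_vs_rest, to_binary and fake_labels are made up for illustration and are not part of torchvision. A module-level function wrapped in functools.partial should generally stay picklable, which matters when the DataLoader uses worker processes.

```python
from collections import Counter
from functools import partial


def one_vs_rest(y, pos_labels):
    """Map a label to 1 if it is in pos_labels, else 0 (hypothetical helper)."""
    return int(y in pos_labels)


# Callable that could be passed as target_transform: class 5 vs. the rest
to_binary = partial(one_vs_rest, pos_labels=frozenset({5}))

# Stand-in labels covering the ten STL-10 class indices 0..9
fake_labels = list(range(10))

binary = [to_binary(y) for y in fake_labels]
counts = Counter(binary)

print(binary)
print(counts)
```

With one positive class out of ten, only one of the ten stand-in labels maps to 1, so it is worth checking the class balance of the real dataset the same way.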
Upvotes: 1
Reputation: 1402
One way is to update the label values at runtime, before passing them to the loss function in the training loop. Let's say we want to relabel class 5 as 1, and the rest as 0:
import torch

my_class_id = 5
for imgs, labels in train_dataloader:
    # 1 where the label equals my_class_id, 0 everywhere else
    labels = torch.where(labels == my_class_id, 1, 0)
    ...
You may also need to apply the same relabeling to test_dataloader. Also, I am not sure about the datatype of labels. If it's float, change accordingly.
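On the datatype point, here is a small sketch of how the relabeled tensor could be cast for two common loss functions. The batch of labels is a stand-in I made up, and the choice of loss functions is an assumption about the training setup, not something stated in the question. A boolean comparison followed by a cast is used here as an equivalent alternative to torch.where.

```python
import torch

my_class_id = 5  # same illustrative class id as above

# Stand-in batch of integer labels (made up for this sketch)
labels = torch.tensor([5, 0, 3, 5, 9])

# The comparison yields a bool tensor; cast it to the dtype the loss expects
binary_long = (labels == my_class_id).long()    # e.g. for nn.CrossEntropyLoss with 2 logits
binary_float = (labels == my_class_id).float()  # e.g. for nn.BCEWithLogitsLoss with 1 logit

# BCEWithLogitsLoss expects targets shaped like the logits, here (batch, 1)
logits = torch.zeros(5, 1)  # placeholder model output
loss = torch.nn.BCEWithLogitsLoss()(logits, binary_float.unsqueeze(1))

print(binary_long, binary_float, loss.item())
```

Which cast is right depends on the loss: integer class indices for a two-logit cross-entropy head, float targets for a single-logit binary cross-entropy head.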
Upvotes: 0
Reputation: 509
You need to relabel the images. Initially, class 0 corresponds to label 0, class 1 corresponds to label 1, ..., and class 9 corresponds to label 9. To turn this into binary classification, change the label of the images in category 1 (or whichever class you choose) to 0, and the label of the images in all other categories to 1.
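One possible way to do this relabeling is a thin wrapper around the dataset. This is only a sketch: OneVsRestDataset, target_class and target_label are hypothetical names, and the list of (image, label) pairs stands in for the real dataset. In real code the wrapper would typically subclass torch.utils.data.Dataset. It is left as a plain class here so the sketch runs without PyTorch.

```python
class OneVsRestDataset:
    """Wrap a map-style dataset of (image, label) pairs and binarize the labels.

    Hypothetical helper for illustration; not part of torch or torchvision.
    """

    def __init__(self, base, target_class, target_label=0):
        self.base = base
        self.target_class = target_class
        self.target_label = target_label  # must be 0 or 1

    def __len__(self):
        return len(self.base)

    def __getitem__(self, index):
        image, label = self.base[index]
        # The chosen class gets target_label; every other class gets the opposite
        if int(label) == self.target_class:
            return image, self.target_label
        return image, 1 - self.target_label


# Stand-in for a real dataset: ten (image, label) pairs with labels 0..9
fake_data = [(f"img{i}", i) for i in range(10)]

# Convention described above: category 1 becomes 0, everything else becomes 1
wrapped = OneVsRestDataset(fake_data, target_class=1, target_label=0)
labels = [wrapped[i][1] for i in range(len(wrapped))]

print(labels)
```

Passing target_label=1 instead gives the opposite convention, where the chosen class is 1 and the rest are 0.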
Upvotes: 0