Sahar Millis

Reputation: 897

How to transform labels in PyTorch to one-hot

How can I give target_transform a function that changes the labels to a one-hot encoding?

For example, the MNIST dataset in torchvision:

train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', 
                                           train=True,
                                           download=True,
                                           transform=train_transform,
                                           target_transform=<????>)

I tried F.one_hot(), but it didn't work.

Upvotes: 1

Views: 8036

Answers (2)

Mranal Jadhav

Reputation: 29

Use a user-defined lambda function to turn the integer label into a one-hot encoded tensor.

from torchvision.transforms import Lambda

train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', train=True,
    download=True, transform=train_transform,
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)))
  • It first creates a zero tensor of size 10 (the number of labels in our dataset), then calls scatter_, which writes value=1 at the index given by the label y.
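
To see what that lambda produces without downloading MNIST, here is a minimal sketch that applies the same zeros + scatter_ step to a single integer. The names NUM_CLASSES and to_one_hot are made up for illustration, and 10 classes is assumed as in the answer.

```python
import torch

NUM_CLASSES = 10  # assumption: 10 digit classes, as in the answer above


def to_one_hot(y: int) -> torch.Tensor:
    """Hypothetical helper wrapping the same zeros + scatter_ expression."""
    # Start from a float vector of zeros, then write 1 at index y along dim 0.
    return torch.zeros(NUM_CLASSES, dtype=torch.float).scatter_(
        0, torch.tensor(y), value=1
    )


encoded = to_one_hot(3)
print(encoded.shape, encoded.dtype)
print(encoded)
```

A named function like this can be passed directly as target_transform, since the dataset only needs a callable that takes the label.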

Upvotes: 1

joseppinilla

Reputation: 96

This is how I implemented it. Not sure if there's a cleaner way.

import torch
import torch.nn.functional as F
import torchvision

train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
                                 transform=torchvision.transforms.ToTensor(),
                                 target_transform=torchvision.transforms.Compose([
                                 lambda x: torch.LongTensor([x]), # or just torch.tensor
                                 lambda x: F.one_hot(x, 10)]),    # result has shape (1, 10)
                                 download=True)
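  • F.one_hot needs an index tensor, i.e. int64 (a LongTensor), so the integer label has to be converted first.

  • torchvision.transforms.ToTensor can't do that conversion because the label is not an image.

  • torch.LongTensor and torch.tensor behave differently with an int input: torch.LongTensor(3) allocates an uninitialized tensor of size 3, while torch.tensor(3) creates a 0-dim tensor holding the value 3. That is why the label is wrapped in a list above.

  • The number of classes has to be provided; otherwise F.one_hot infers it from the largest value in the input, which gives the wrong length for a single label.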
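
The points above can be checked on a single label without loading the dataset. This is only a sketch: the label value 3 and the variable names are made up for illustration.

```python
import torch
import torch.nn.functional as F

label = 3  # hypothetical integer label, as the dataset would return it

# torch.tensor(3) stores the value 3 in a 0-dim int64 tensor.
idx = torch.tensor(label)
# The legacy constructor torch.LongTensor(3) instead allocates
# three uninitialized elements, so it is not a drop-in replacement.
sized = torch.LongTensor(3)

# With num_classes given, the output has a fixed length of 10.
one_hot = F.one_hot(idx, num_classes=10)
# Without it, the length is inferred as max value + 1 (4 here).
inferred = F.one_hot(idx)
# Wrapping the label in a list adds a leading dimension of size 1.
wrapped = F.one_hot(torch.LongTensor([label]), num_classes=10)

print(idx.shape, sized.shape)
print(one_hot.shape, inferred.shape, wrapped.shape)
```

If a flat vector of length 10 is wanted rather than shape (1, 10), converting with torch.tensor(x) instead of torch.LongTensor([x]) avoids the extra dimension. Note that F.one_hot returns an int64 tensor, so a .float() call may be needed depending on the loss function.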

Upvotes: 2
