How do I give target_transform a function that converts the labels to one-hot encoding?
For example, the MNIST dataset in torchvision:
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/',
                                           train=True,
                                           download=True,
                                           transform=train_transform,
                                           target_transform=<????>)
I tried F.one_hot(), but it didn't work.
Use a user-defined lambda function, wrapped in torchvision.transforms.Lambda, to turn the integer label into a one-hot encoded tensor.
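import torch
import torchvision
from torchvision.transforms import Lambda
train_dataset = torchvision.datasets.MNIST(root='./mnist_data/', train=True,
    download=True, transform=train_transform,  # train_transform as defined in the question
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1)))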
This is how I implemented it. Not sure if there's a cleaner way.
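import torch
import torch.nn.functional as F
import torchvision
train_dataset = torchvision.datasets.MNIST(root='./data/', train=True,
    transform=torchvision.transforms.ToTensor(),
    target_transform=torchvision.transforms.Compose([
        lambda x: torch.LongTensor([x]),  # int label -> int64 tensor of shape (1,); torch.tensor([x]) works too
        lambda x: F.one_hot(x, 10)]),     # -> int64 one-hot tensor of shape (1, 10)
    download=True)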
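Notes:
- F.one_hot needs an index tensor, i.e. int64 (torch.long).
- You can't use torchvision.transforms.ToTensor on the label, because it's not an image.
- torch.LongTensor and torch.tensor behave differently with an int input: torch.LongTensor(3) creates an uninitialized tensor of length 3, while torch.tensor(3) creates a 0-dim tensor holding the value 3.
- You need to provide the number of classes (10 for MNIST); otherwise F.one_hot infers it from the largest value in its input.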
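The notes above can be checked directly. This is a minimal sketch; the variable names (`label`, `scalar`, `sized`, `wrapped`) are hypothetical and chosen only for illustration.

```python
import torch
import torch.nn.functional as F

label = 3

# torch.tensor(int) gives a 0-dim int64 tensor holding the value.
scalar = torch.tensor(label)
print(scalar.shape, scalar.dtype)  # torch.Size([]) torch.int64

# torch.LongTensor(int) treats the int as a size: a length-3 tensor of
# uninitialized values, not a tensor containing 3.
sized = torch.LongTensor(label)
print(sized.shape)                 # torch.Size([3])

# Wrapping the int in a list gives a 1-element tensor holding the value.
wrapped = torch.LongTensor([label])
print(wrapped.tolist())            # [3]

# Without num_classes, one_hot infers it as max(input) + 1,
# so a single label 3 yields only 4 entries.
print(F.one_hot(scalar).shape)     # torch.Size([4])

# With num_classes=10 the width is fixed regardless of the label value.
print(F.one_hot(scalar, num_classes=10).shape)  # torch.Size([10])
```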