jbssm

Reputation: 7161

PyTorch, select batches according to label in data column

I have a dataset like this:

index  tag   feature1  feature2  target
1      tag1  1.4342    88.4554   0.5365
2      tag1  2.5656    54.5466   0.1263
3      tag2  5.4561    845.556   0.8613
4      tag3  6.5546    8.52545   0.7864
5      tag3  8.4566    945.456   0.4646

The number of entries in each tag is not always the same.

My objective is to load only the data with a specific tag or tags, so that with batch_size=1 I get only the entries in tag1 for one mini-batch and only the entries in tag2 for another, or, for instance, the entries in tag1 and tag2 together if I set batch_size=2.
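
To make the intent concrete, here is a small illustration, assuming the table above lives in a pandas DataFrame (that is only for illustration, my actual data is loaded elsewhere):

import pandas as pd

# the example rows above, as a hypothetical DataFrame
df = pd.DataFrame({
    "tag":      ["tag1", "tag1", "tag2", "tag3", "tag3"],
    "feature1": [1.4342, 2.5656, 5.4561, 6.5546, 8.4566],
    "feature2": [88.4554, 54.5466, 845.556, 8.52545, 945.456],
    "target":   [0.5365, 0.1263, 0.8613, 0.7864, 0.4646],
}, index=[1, 2, 3, 4, 5])

# with batch_size=1, each mini-batch should correspond to exactly one of these groups
for tag, group in df.groupby("tag"):
    print(tag, group.index.tolist())
# tag1 [1, 2]
# tag2 [3]
# tag3 [4, 5]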

The code I have so far completely disregards the tag label and just chooses the batches randomly.

I built the datasets like this:

# features is a matrix with all the features columns through all rows
# target is a vector with the target column through all rows
featuresTrain, targetTrain = projutils.get_data(train=True, config=config)
train = torch.utils.data.TensorDataset(featuresTrain, targetTrain)
train_loader = make_loader(train, batch_size=config.batch_size)
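
For reference, with the example rows above these tensors would look roughly like this (the real projutils.get_data reads them from my data files, so this is just a stand-in):

import torch

# hypothetical stand-in for projutils.get_data, using the example rows above
featuresTrain = torch.tensor([[1.4342, 88.4554],
                              [2.5656, 54.5466],
                              [5.4561, 845.556],
                              [6.5546, 8.52545],
                              [8.4566, 945.456]], dtype=torch.float32)
targetTrain = torch.tensor([0.5365, 0.1263, 0.8613, 0.7864, 0.4646], dtype=torch.float32)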

And my loader (generically) looks like this:

def make_loader(dataset, batch_size):
    loader = torch.utils.data.DataLoader(dataset=dataset,
                                         batch_size=batch_size,
                                         shuffle=True,
                                         pin_memory=True,
                                         num_workers=8)
    return loader

Which I then train like this:

for epoch in range(config.epochs):
    for _, (features, target) in enumerate(train_loader):
        loss = train_batch(features, target, model, optimizer, criterion)

And the train_batch:

def train_batch(features, target, model, optimizer, criterion):
    features, target = features.to(device), target.to(device)

    # Forward pass ➡
    outputs = model(features)
    loss = criterion(outputs, target)
    return loss

Upvotes: 1

Views: 2819

Answers (1)

DerekG

Reputation: 3958

Here is a simple dataset that roughly implements the characteristics you're looking for, as best as I can tell.

import torch
from torch.utils import data


class CustomDataset(data.Dataset):
    def __init__(self, featuresTrain, targetsTrain, tagsTrain, sample_equally=False):
        # self.tags should be a tensor in k-hot encoding form, so a 2D tensor
        self.tags = tagsTrain
        self.x = featuresTrain
        self.y = targetsTrain
        self.unique_tagsets = None
        self.sample_equally = sample_equally

        # self.active_tags is a 1D k-hot encoding vector
        self.active_tags = self.get_random_tag_set()

    def get_random_tag_set(self):
        # gets all unique sets of tags and returns one randomly
        if self.unique_tagsets is None:
            self.unique_tagsets = self.tags.unique(dim=0)
        if self.sample_equally:
            rand_idx = torch.randint(len(self.unique_tagsets), (1,)).item()
            return self.unique_tagsets[rand_idx]
        else:
            rand_idx = torch.randint(len(self.tags), (1,)).item()
            return self.tags[rand_idx]

    def set_tags(self, tags):
        # specifies the set of tags that must be present for a datum to be selected
        self.active_tags = tags

    def __getitem__(self, index):
        # get all indices of elements whose tag set matches self.active_tags
        indices = torch.where((self.tags == self.active_tags).all(dim=1))[0]

        # we select an index based on the indices of the elements that have the tag set
        idx = indices[index % len(indices)]

        item = self.x[idx], self.y[idx]
        return item

    def __len__(self):
        return len(self.y)

This dataset randomly selects a set of tags. Then, every time __getitem__() is called, it uses the specified index to select from amongst the data elements that have that set of tags. You can call set_tags(), or get_random_tag_set() followed by set_tags(), after each minibatch or however often you want to change up the tagset, or you can manually specify the tagset yourself. The dataset inherits from torch.utils.data.Dataset, so you should be able to use it with a torch.utils.data.DataLoader without modification.
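
As a rough sketch of that usage (the tensor and variable names are taken from your question; keep num_workers at 0 here so that set_tags() affects the same dataset object the loader reads from):

dataset = CustomDataset(featuresTrain, targetTrain, tagsTrain)
loader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True)

for features, target in loader:
    loss = train_batch(features, target, model, optimizer, criterion)
    # every element of this batch shared dataset.active_tags;
    # draw a new tag set before the next batch is fetched
    dataset.set_tags(dataset.get_random_tag_set())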

You can specify whether you'd like to sample each set of tags according to its prevalence, or whether you'd like to sample all tagsets equally regardless of how many elements have that set, using sample_equally.

In short, this dataset is a tiny bit rough around the edges but should allow you to sample batches all with the same tag set. The main shortcoming is that each element will likely be sampled more than once per batch.

For the initial encoding, let's say that to start each data example has a list of tags, so tags is a list of lists, each sublist containing that example's tags. The following code converts this to k-hot encoding:

def to_k_hot(tags):
    all_tags = []
    for ex in tags:
        for tag in ex:
            all_tags.append(tag)
    unique_tags = sorted(set(all_tags))  # remove duplicates and fix a column order

    tagsTrain = torch.zeros([len(tags), len(unique_tags)])
    for i in range(len(tags)):  # index through all examples
        for j in range(len(unique_tags)):  # index through all unique_tags
            if unique_tags[j] in tags[i]:
                tagsTrain[i, j] = 1

    return tagsTrain

As an example, say you had the following tags for a dataset:

tags = [ ['tag1'],
         ['tag1','tag2'],
         ['tag3'],
         ['tag2'],
         [],
         ['tag1','tag2','tag3'] ]

Calling to_k_hot(tags) would return:

tensor([[1., 0., 0.],
        [1., 1., 0.],
        [0., 0., 1.],
        [0., 1., 0.],
        [0., 0., 0.],
        [1., 1., 1.]])
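
And since each row in your table carries exactly one tag, you can wrap each tag in its own single-element list before encoding (the list below is just your five example rows written out by hand):

row_tags = [['tag1'], ['tag1'], ['tag2'], ['tag3'], ['tag3']]
tagsTrain = to_k_hot(row_tags)  # shape [5, 3], columns ordered tag1, tag2, tag3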

Upvotes: 3
