sidra Aleem

Reputation: 462

Intuition behind weighted random sampler in PyTorch

I am trying to use WeightedRandomSampler to handle imbalance in my dataset (class 1: 2555, class 2: 227, class 3: 621, class 4: 2552 images). I have debugged each step, but the intuition behind the sampler is still not clear to me. My target labels are one-hot encoded vectors, as shown below.

train_labels.head(5)

(image: the first five rows of the one-hot encoded target labels)

I converted the one-hot labels to class indices as follows:

labels = np.argmax(train_labels.loc[:, 'none':'both'].values, axis=1)
train_labels = torch.from_numpy(labels)
train_labels
tensor([0, 0, 1,  ..., 1, 0, 0])

Below are the steps I used to compute the weights for the weighted random sampler (a consolidated sketch follows the list). Please correct me if my interpretation of any step is wrong.

  1. Count the number of samples per class in the dataset

    class_sample_count = np.array(train_labels.value_counts()) 
    class_sample_count
    array([2555, 2552,  621,  227])
    
  2. Calculate the weight associated with each class

    weight = 1. / class_sample_count 
    weight
    array([0.00039139, 0.00039185, 0.00161031, 0.00440529])
    
  3. Calculate the weight for each of the samples in the dataset.

    samples_weight = np.array(weight[train_labels])
    print(samples_weight[1], samples_weight[2])
    0.0003913894324853229 0.00039184952978056425 
    

  4. Convert the np.array to a tensor

    samples_weight = torch.from_numpy(samples_weight)
    samples_weight
    tensor([0.0004, 0.0004, 0.0004,  ..., 0.0004, 0.0004, 0.0004],
           dtype=torch.float64)
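For reference, here is the same pipeline as one self-contained sketch. It reuses the labels array from the argmax step above, and it uses np.bincount instead of value_counts() so that the counts stay in class-index order (value_counts() sorts by frequency, which can misalign weight[train_labels] when the frequency order differs from the class order):

    import numpy as np
    import torch

    # counts per class, ordered by class index 0..3
    class_sample_count = np.bincount(labels)
    # inverse-frequency weight per class
    weight = 1. / class_sample_count
    # one weight per sample, looked up by its class index
    samples_weight = torch.from_numpy(weight[labels])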

After the conversion to a tensor, all the samples appear to have the same value in all four entries. How, then, is weighted random sampling helping to deal with the imbalanced dataset?

I would be grateful for any help. Thank you.

Upvotes: 0

Views: 2248

Answers (1)

Ivan

Reputation: 40708

This is because you are computing the weights on the one-hot encodings: since each encoding has four components (one per class), the indexing weight[train_labels] leaves you with four identical weights per instance. That is perfectly fine, because each instance should carry a single weight anyway; to the sampler, that weight is the relative probability of picking the instance. If a given class is prominent in the dataset, its count is high and its weight (the inverse of the count) is correspondingly low, so instances of that class have a lower probability of being sampled from the dataset.

Over a fairly large number of draws, the goal of this weighting scheme is to produce balanced sampling even though the class representations are imbalanced.
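As a quick sanity check, here is a minimal self-contained sketch (with made-up labels, not the question's data) that draws indices from a WeightedRandomSampler and counts how often each class gets picked; despite a 10:1 imbalance, the two classes should come out roughly even:

>>> import torch
>>> from torch.utils.data import WeightedRandomSampler
>>> labels = torch.cat([torch.zeros(1000), torch.ones(100)]).long()  # 10:1 imbalance
>>> weights = 1. / torch.bincount(labels).float()                    # inverse class frequency
>>> sampler = WeightedRandomSampler(weights[labels], num_samples=len(labels))
>>> idx = torch.tensor(list(sampler))                                # indices drawn by the sampler
>>> torch.bincount(labels[idx])                                      # both counts should be near 550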

If you stick with one-hot encodings, you can just pick the first column:

>>> sample_weights = np.array(weight[train_labels])[:,0]

Then use WeightedRandomSampler to construct a sampler:

>>> from torch.utils.data import DataLoader, WeightedRandomSampler
>>> sampler = WeightedRandomSampler(sample_weights, len(train_labels))

Finally, you can plug it into a DataLoader:

>>> DataLoader(dataset, batch_size, sampler=sampler)
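Note that WeightedRandomSampler draws with replacement by default (replacement=True), so minority-class samples can appear several times per epoch; the weights do not have to sum to one, and the num_samples argument (here len(train_labels)) sets how many samples are drawn per pass through the loader.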

Upvotes: 2
