mahnoor.fatima

Reputation: 1

PyTorch Custom Dataset CPU OOM Issue

I'm having a very persistent memory issue where my dataloader fills up CPU memory after an arbitrary number of epochs (5-6), depending on num_workers.

I'm 85% confident that the issue is with the dataset, because every call to __getitem__() increases memory.
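
For reference, here is roughly how I measured that (a minimal sketch assuming psutil is available; dataset is a placeholder for an instance of my dataset class):

import os
import psutil  # third-party, used only to read the process's resident set size

proc = psutil.Process(os.getpid())
rss_start = proc.memory_info().rss

for i in range(100):
    _ = dataset[i % len(dataset)]  # call __getitem__ directly, no DataLoader

rss_end = proc.memory_info().rss
print(f"RSS grew by {(rss_end - rss_start) / 2**20:.1f} MiB over 100 calls")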

My data is just a list of directories that I load and process in __getitem__() (shown below). I'm not getting any CUDA errors.

def transformations(self, input):
    # Sample one set of crop params so every modality gets the same crop.
    i, j, h, w = self.crop.get_params(input['target_material'][0], scale=(0.7, 1.0), ratio=(1.0, 1.0))
    if self.use_m1:
        input["m1"] = torch.cat([self.color_jitter(TF.resized_crop(sample, i, j, h, w, size=(256, 256))) for sample in input['m1']], dim=0)
    if self.use_m2:
        input["m2"] = torch.cat([TF.resized_crop(sample, i, j, h, w, size=(self.img_size, self.img_size), interpolation=TF.InterpolationMode.NEAREST) for sample in input['m2']], dim=0)
    return input
def __getitem__(self, idx):
    source = self.data[idx]
    # Re-draw the target up to 10 times until it differs enough from the source.
    for _ in range(10):
        target = select_target(self.data)
        if diff(source, target) >= 0.10:
            break

    sample = {}
    if self.use_m1:
        # m1_dirs: file paths for modality 1 (gathered elsewhere in the class)
        sample['m1'] = torch.stack([
            to_tensor(normalize_images(np.transpose(cv2.resize(read_npz(m1_input), dsize=(256, 256), interpolation=cv2.INTER_AREA)[..., :3], (2, 0, 1)), max=255))
            for m1_input in m1_dirs
        ], dim=0)
    if self.use_m2:
        # m2_dirs: semantic maps; nearest-neighbour resize preserves label values
        sample['m2'] = torch.stack([
            to_tensor(normalize_images(np.expand_dims(cv2.resize(m2_input, dsize=(256, 256), interpolation=cv2.INTER_NEAREST), axis=0), max=40))
            for m2_input in m2_dirs
        ], dim=0)

    if self.config["transform"] and random.random() < 0.5 and self.train:
        sample = self.transformations(sample)
    else:
        if self.use_m1:
            sample['m1'] = torch.cat([m1_tensor for m1_tensor in sample['m1']], dim=0)
        if self.use_m2:
            sample['m2'] = torch.cat([m2_tensor for m2_tensor in sample['m2']], dim=0)

    return sample

After tracing memory, I find that even just reading the data (with no transformations applied) causes an increase in memory that isn't released and keeps accumulating. Eventually I get an OOM error and my code fails.
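
This is roughly how I traced it (a minimal sketch using the standard-library tracemalloc; dataset again stands in for an instance of my dataset class):

import tracemalloc

tracemalloc.start()
snap_before = tracemalloc.take_snapshot()

# Pull samples straight from the dataset, bypassing the DataLoader entirely
for i in range(100):
    _ = dataset[i % len(dataset)]

snap_after = tracemalloc.take_snapshot()
# Show the source lines that accumulated the most memory between snapshots
for stat in snap_after.compare_to(snap_before, 'lineno')[:10]:
    print(stat)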

Upvotes: 0

Views: 84

Answers (0)
